|
import json |
|
import cv2 |
|
import numpy as np |
|
import os |
|
import random |
|
from glob import glob |
|
|
|
from torch.utils.data import Dataset |
|
from torchvision import transforms |
|
from PIL import Image, ImageDraw |
|
import torch |
|
import albumentations as A |
|
|
|
|
|
class ZalandoDataset(Dataset):
    """Zalando / VITON-HD style try-on dataset.

    Pairs a person image (``image/``) with a reference cloth image
    (``cloth/``) and a human-parsing label map (``image-parse-v3/``),
    and builds a noisy garment mask plus a garment-masked person image
    for inpainting-style training.

    Each item is a dict with:
        jpg:          person image, float32 RGB in [-1, 1], (H, W, 3)
        txt:          caption randomly drawn from ``self.prompts``
        hint:         cloth image, float32 RGB in [0, 1], (224, 224, 3)
        mask:         noisy grayscale garment mask, uint8 0/255, (H, W, 1)
        masked_image: person image with garment region zeroed, [0, 1]
        path:         filesystem path of the cloth image
    """

    def __init__(self, transform, root="/tmp/zalando/train/", width=512, height=512):
        """
        Args:
            transform: truthy -> apply the albumentations augmentation
                pipeline to the reference cloth image.
            root: dataset root containing ``image/``, ``cloth/`` and
                ``image-parse-v3/`` subdirectories.
            width, height: output resolution of the person image and mask.
        """
        self.root = root
        self.transform = transform
        self.width = width
        self.height = height
        # Sorting keeps the three file lists index-aligned; assumes the
        # three directories contain matching filenames -- TODO confirm.
        self.image_paths = sorted(glob(f'{self.root}image/*.jpg'))
        self.ref_paths = sorted(glob(f'{self.root}cloth/*.jpg'))
        self.parse_paths = sorted(glob(f"{self.root}image-parse-v3/*.png"))
        self.prompts = ["", "a professional, detailed, high-quality image", "shirt"]
        # Coarse body-part index -> [name, list of raw parse label ids].
        self.labels = {
            0: ['background', [0, 10]],
            1: ['hair', [1, 2]],
            2: ['face', [4, 13]],
            3: ['upper', [5, 6, 7]],
            4: ['bottom', [9, 12]],
            5: ['left_arm', [14]],
            6: ['right_arm', [15]],
            7: ['left_leg', [16]],
            8: ['right_leg', [17]],
            9: ['left_shoe', [18]],
            10: ['right_shoe', [19]],
            11: ['socks', [8]],
            12: ['noise', [3, 11]],
        }
        # Light augmentation for the cloth image only.
        self.random_trans = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=20),
            A.Blur(p=0.3),
        ])

    def img_segment(self, parse_img, wanted_label=3):
        """Extract a binary uint8 mask (0/255) for one coarse label.

        Args:
            parse_img: PIL image whose pixel values are raw parse label ids.
            wanted_label: key into ``self.labels`` (3 = 'upper' garment).

        Returns:
            (height, width) uint8 array, 255 inside the region, 0 outside.
        """
        h, w = self.height, self.width
        # NEAREST keeps label ids intact; any interpolating filter would
        # blend neighbouring ids into meaningless values.
        im_parse_pil = transforms.Resize(
            (h, w), interpolation=transforms.InterpolationMode.NEAREST
        )(parse_img)
        parse = torch.from_numpy(np.array(im_parse_pil)[None]).long()
        # One-hot encode the 20 raw parse labels along dim 0.
        parse_map = torch.FloatTensor(20, h, w).zero_()
        parse_map = parse_map.scatter_(0, parse, 1.0)
        # Collapse the 20 raw labels into the 13 coarse groups.
        new_parse_map = torch.FloatTensor(len(self.labels), h, w).zero_()
        for i in range(len(self.labels)):
            for label in self.labels[i][1]:
                new_parse_map[i] += parse_map[label]

        region = new_parse_map[wanted_label].numpy()
        # Binarize BEFORE scaling: summed overlapping labels could exceed 1
        # and wrap around when multiplied by 255 in uint8.
        return (region > 0).astype(np.uint8) * 255

    def add_noise(self, image):
        """Randomly dilate/perturb a BGR mask around one of its contours.

        Picks a random external contour of the mask, thickens it, and
        stamps thick short line segments (effectively dots, see note) at
        every contour point, so the garment mask does not leak the exact
        garment silhouette to the model.

        Args:
            image: 3-channel uint8 mask image (0/255).

        Returns:
            3-channel uint8 mask with boundary noise OR-ed in.
        """
        image = image.astype(np.uint8)
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        contours, _ = cv2.findContours(gray, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        if contours:
            random_contour = contours[np.random.randint(len(contours))]

            # Draw the chosen contour thickly, then dilate it further.
            canvas = np.zeros_like(gray)
            cv2.drawContours(canvas, [random_contour], 0, 255, thickness=10)
            kernel = np.ones((15, 15), np.uint8)
            canvas = cv2.dilate(canvas, kernel, iterations=1)

            # Band where the dilated outline and the mask disagree.
            boundary = cv2.absdiff(canvas, gray)

            # Contour points come as (N, 1, 2); flatten to (N, 2) (x, y).
            points_on_boundary = random_contour.reshape(-1, 2)

            thickness = 30
            # NOTE(review): length 0.1 px means the endpoint is essentially
            # the start point, so each "line" is a thick dot -- confirm
            # whether a longer stroke was intended.
            length = 0.1
            for point in points_on_boundary:
                angle = np.random.randint(0, 360)
                start = (int(point[0]), int(point[1]))
                endpoint = (int(point[0] + length * np.cos(angle * np.pi / 180)),
                            int(point[1] + length * np.sin(angle * np.pi / 180)))
                cv2.line(boundary, start, endpoint, 255, thickness)

            image = cv2.bitwise_or(image, cv2.cvtColor(boundary, cv2.COLOR_GRAY2BGR))
        return image

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        source_filename = self.ref_paths[idx]
        target_filename = self.image_paths[idx]
        parse_filename = self.parse_paths[idx]

        prompt = random.choice(self.prompts)

        # Reference cloth image (conditioning "hint"), optionally augmented.
        source = cv2.imread(source_filename)
        source = cv2.resize(source, (224, 224))
        if self.transform:
            source = self.random_trans(image=source)["image"]

        # Person image at the requested output resolution.
        target = cv2.imread(target_filename)
        target = cv2.resize(target, (self.width, self.height))

        # Parse label map: NEAREST resampling preserves the integer label
        # ids (the PIL default filter would blend them).
        parse = Image.open(parse_filename).resize(
            (self.width, self.height), resample=Image.NEAREST
        )
        mask = self.img_segment(parse, 3)

        # img_segment returns a single-channel mask, so the correct
        # conversion is GRAY2BGR (RGB2BGR would raise on 2-D input).
        mask = cv2.cvtColor(np.array(mask), cv2.COLOR_GRAY2BGR)

        mask = self.add_noise(mask)
        mask_gray = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
        mask_gray = np.expand_dims(mask_gray, axis=-1)

        # cv2 loads BGR; the model consumes RGB.
        source = cv2.cvtColor(source, cv2.COLOR_BGR2RGB)
        target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB)

        mask = mask.astype(np.float32) / 255.0
        source = source.astype(np.float32) / 255.0
        target0 = target.astype(np.float32) / 255.0
        # Zero out the garment region of the person image.
        masked_image = target0 * (mask < 0.5)

        # [-1, 1] normalization for the diffusion target.
        target_normalized = (target.astype(np.float32) / 127.5) - 1.0

        return dict(jpg=target_normalized, txt=prompt, hint=source,
                    mask=mask_gray, masked_image=masked_image,
                    path=source_filename)
|
|