# Manga / dataset/datasets.py
# Keiser41's picture
# Upload 47 files
# 62456b0
# raw
# history blame contribute delete
# No virus
# 3.31 kB
from PIL import Image
import torch
import os
import numpy as np
import torchvision.transforms as transforms
from utils.utils import generate_mask
class TrainDataset(torch.utils.data.Dataset):
    """Training dataset pairing each colour image with its two greyscale inputs.

    For every file name found in ``<data_path>/color`` a sample is built from:

    * ``<data_path>/color/<name>``   -- opened and converted to RGB;
    * ``<data_path>/bw/<name>``      -- opened and converted to 'L';
    * ``<data_path>/bw/dfm_<name>``  -- opened and converted to 'L'.

    The two 'L' images are stacked into one two-channel array so that the
    optional ``transform`` receives them together as its ``mask`` argument.
    """

    def __init__(self, data_path, transform=None):
        """Index the files of ``<data_path>/color``.

        Args:
            data_path: Root directory holding the ``color`` and ``bw`` folders.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping with
                ``'image'`` and ``'mask'`` entries (presumably an
                albumentations-style pipeline -- TODO confirm).
        """
        # os.listdir() yields entries in arbitrary order; sorting keeps the
        # index -> file mapping identical across runs and machines.
        self.data = sorted(os.listdir(os.path.join(data_path, 'color')))
        self.data_path = data_path
        self.transform = transform
        self.ToTensor = transforms.ToTensor()

    def __len__(self):
        """Return the number of indexed colour images."""
        return len(self.data)

    def _read_image(self, folder, file_name, mode):
        """Load ``<data_path>/<folder>/<file_name>`` as an array in PIL ``mode``."""
        # The context manager releases the file handle deterministically
        # instead of leaving it to garbage collection.
        with Image.open(os.path.join(self.data_path, folder, file_name)) as img:
            return np.array(img.convert(mode))

    def __getitem__(self, idx):
        """Build the training sample stored at position ``idx``.

        Returns:
            A 4-tuple ``(bw_img, bw_img, color_img, hint)``:

            * ``bw_img``    -- tensor made from the stacked ``bw`` / ``dfm_``
              images (two channels);
            * ``color_img`` -- colour tensor after ``(x - 0.5) / 0.5``;
            * ``hint``      -- ``color_img * mask`` concatenated with ``mask``
              along dimension 0.

        Any error raised while opening a missing or unreadable file propagates
        to the caller.
        """
        image_name = self.data[idx]

        color_img = self._read_image('color', image_name, 'RGB')
        bw_img = self._read_image('bw', image_name, 'L')
        dfm_img = self._read_image('bw', 'dfm_' + image_name, 'L')

        # (H, W) + (H, W) -> (H, W, 2): both greyscale planes travel together
        # so the transform applies the same operation to each of them.
        bw_img = np.stack([bw_img, dfm_img], axis=2)

        if self.transform:
            result = self.transform(image=color_img, mask=bw_img)
            color_img = result['image']
            bw_img = result['mask']

        color_img = self.ToTensor(color_img)
        bw_img = self.ToTensor(bw_img)

        # Maps a [0, 1] range onto [-1, 1] (ToTensor scales uint8 input to
        # [0, 1] per the torchvision documentation).
        color_img = (color_img - 0.5) / 0.5

        # generate_mask comes from utils.utils and is called with the last two
        # dimensions of bw_img (height and width after ToTensor).
        mask = generate_mask(bw_img.shape[1], bw_img.shape[2])
        hint = torch.cat((color_img * mask, mask), 0)

        # NOTE(review): the first two entries are the same two-channel tensor;
        # kept as-is because callers unpack four values -- confirm intent.
        return bw_img, bw_img, color_img, hint
class FineTuningDataset(torch.utils.data.Dataset):
    """Fine-tuning dataset indexed by the files of ``<data_path>/real_manga``.

    The sample names are the entries of ``real_manga`` whose name does not
    contain ``'_dfm'``. For each name, ``__getitem__`` reads:

    * ``<data_path>/color/<name>``   -- opened and converted to RGB;
    * ``<data_path>/bw/<name>``      -- opened and converted to 'L';
    * ``<data_path>/bw/dfm_<name>``  -- opened and converted to 'L'.

    NOTE(review): names are listed from ``real_manga`` but the files are read
    from ``color`` and ``bw``; ``color_data`` and ``mults_amount`` are stored
    but never used by ``__getitem__``; and the ``'_dfm'`` filter does not match
    the ``'dfm_'`` prefix used when loading. Behaviour is left unchanged here
    -- confirm the intended directory layout.
    """

    def __init__(self, data_path, transform=None, mult_amount=1):
        """Index the sample names and the (shuffled) colour file names.

        Args:
            data_path: Root directory holding the ``real_manga``, ``color`` and
                ``bw`` folders.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping with
                ``'image'`` and ``'mask'`` entries (presumably an
                albumentations-style pipeline -- TODO confirm).
            mult_amount: Stored as ``self.mults_amount``; not read anywhere in
                this class.
        """
        # os.listdir() yields entries in arbitrary order; sorting first keeps
        # the index -> file mapping stable and makes the shuffle below
        # reproducible under a fixed NumPy seed.
        self.data = sorted(
            name
            for name in os.listdir(os.path.join(data_path, 'real_manga'))
            if '_dfm' not in name
        )
        self.color_data = sorted(os.listdir(os.path.join(data_path, 'color')))
        self.data_path = data_path
        self.transform = transform
        self.mults_amount = mult_amount
        np.random.shuffle(self.color_data)
        self.ToTensor = transforms.ToTensor()

    def __len__(self):
        """Return the number of indexed ``real_manga`` entries."""
        return len(self.data)

    def _read_image(self, folder, file_name, mode):
        """Load ``<data_path>/<folder>/<file_name>`` as an array in PIL ``mode``."""
        # The context manager releases the file handle deterministically
        # instead of leaving it to garbage collection.
        with Image.open(os.path.join(self.data_path, folder, file_name)) as img:
            return np.array(img.convert(mode))

    def __getitem__(self, idx):
        """Build the fine-tuning sample stored at position ``idx``.

        Returns:
            A 2-tuple ``(bw_img, color_img)``: the tensor made from the stacked
            ``bw`` / ``dfm_`` images (two channels, returned once) and the
            colour tensor after ``(x - 0.5) / 0.5``.

        Any error raised while opening a missing or unreadable file propagates
        to the caller.
        """
        image_name = self.data[idx]

        color_img = self._read_image('color', image_name, 'RGB')
        bw_img = self._read_image('bw', image_name, 'L')
        dfm_img = self._read_image('bw', 'dfm_' + image_name, 'L')

        # (H, W) + (H, W) -> (H, W, 2): both greyscale planes travel together
        # so the transform applies the same operation to each of them.
        bw_img = np.stack([bw_img, dfm_img], axis=2)

        if self.transform:
            result = self.transform(image=color_img, mask=bw_img)
            color_img = result['image']
            bw_img = result['mask']

        color_img = self.ToTensor(color_img)
        bw_img = self.ToTensor(bw_img)

        # Maps a [0, 1] range onto [-1, 1] (ToTensor scales uint8 input to
        # [0, 1] per the torchvision documentation).
        color_img = (color_img - 0.5) / 0.5

        return bw_img, color_img