# primepake
# add training flowvae
# 4f877a2
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, IterableDataset
from torchvision import transforms
import datasets
from datasets import register
from utils.geometry import make_coord_scale_grid
from models.ldm.dac.audiotools import AudioSignal
import numpy as np
from models.ldm.dac.audiotools.data.datasets import AudioDataset, AudioLoader
from models.ldm.dac.audiotools import transforms as tfm
class BaseWrapperCAE:
    """Shared preprocessing for the CAE dataset wrappers.

    Wraps a dataset built with ``datasets.make`` and converts each square PIL
    image into a dict holding a resized input tensor and, optionally, a
    randomly cropped ground-truth patch stacked with its coordinate and scale
    grids.

    Args:
        dataset: Spec forwarded to ``datasets.make`` to build the wrapped
            dataset.
        resize_inp: Side length the input image is resized to.
        return_gt: When false, ``process`` returns only the ``'inp'`` entry.
        gt_glores_lb: Lower bound of the randomly sampled ground-truth
            resolution. ``None`` keeps the image at its native resolution.
        gt_glores_ub: Upper bound of the sampled resolution (additionally
            capped by the image side). Used only when ``gt_glores_lb`` is set.
        gt_patch_size: Side length of the cropped ground-truth patch. Must be
            set whenever ``return_gt`` is true, because it is used in the
            crop arithmetic.
        p_whole: Probability of resizing the image to exactly
            ``gt_patch_size`` so that the patch covers the whole image.
        p_max: Probability, evaluated with a second draw only when the
            ``p_whole`` draw fails, of using the largest allowed resolution.
    """

    def __init__(
        self,
        dataset,
        resize_inp,
        return_gt=True,
        gt_glores_lb=None,
        gt_glores_ub=None,
        gt_patch_size=None,
        p_whole=0.0,
        p_max=0.0
    ):
        self.dataset = datasets.make(dataset)
        self.resize_inp = resize_inp
        self.return_gt = return_gt
        self.gt_glores_lb = gt_glores_lb
        self.gt_glores_ub = gt_glores_ub
        self.gt_patch_size = gt_patch_size
        self.p_whole = p_whole
        self.p_max = p_max
        # PIL image -> tensor, then (x - 0.5) / 0.5 on every channel.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5),
        ])

    def _resize_for_gt(self, image):
        """Return ``image`` rescaled to the resolution used for the GT crop.

        The order and number of ``random`` draws is significant for seeded
        reproducibility: one draw for ``p_whole``, a second one for ``p_max``
        only if the first fails, and a ``randint`` only if both fail.
        """
        if self.gt_glores_lb is None:
            # No resolution range configured: crop from the native image.
            return image

        if random.random() < self.p_whole:
            r = self.gt_patch_size
        elif random.random() < self.p_max:
            r = min(image.size[0], self.gt_glores_ub)
        else:
            # max(...) keeps the range non-empty when the image (or the upper
            # bound) is smaller than the lower bound.
            r = random.randint(
                self.gt_glores_lb,
                max(self.gt_glores_lb, min(image.size[0], self.gt_glores_ub))
            )
        return image.resize((r, r), Image.LANCZOS)

    def process(self, image):
        """Build the sample dict for one square PIL image.

        Args:
            image: PIL image whose width equals its height.

        Returns:
            A dict with ``'inp'`` (the image resized to ``resize_inp`` and
            passed through ``self.transform``) and, when ``return_gt`` is
            true, ``'gt'``: the cropped patch concatenated along dim 0 with
            the permuted coordinate and scale grids from
            ``make_coord_scale_grid``.

        Raises:
            ValueError: If ``gt_patch_size`` is larger than the (rescaled)
                ground-truth image, so no patch can be cropped.
        """
        assert image.size[0] == image.size[1], 'expected a square image'

        inp_size = (self.resize_inp, self.resize_inp)
        ret = {'inp': self.transform(image.resize(inp_size, Image.LANCZOS))}
        if not self.return_gt:
            return ret

        glo = self.transform(self._resize_for_gt(image))
        p = self.gt_patch_size
        height, width = glo.shape[-2], glo.shape[-1]
        if p > height or p > width:
            # Previously this surfaced as an opaque "empty range" ValueError
            # from random.randint; fail with the same type but a usable hint.
            raise ValueError(
                f'gt_patch_size={p} does not fit into the {height}x{width} '
                'ground-truth image; lower gt_patch_size or raise '
                'gt_glores_lb'
            )

        # Top-left corner of the patch, sampled uniformly (inclusive bounds).
        ii = random.randint(0, height - p)
        jj = random.randint(0, width - p)
        gt_patch = glo[:, ii: ii + p, jj: jj + p]

        # Patch extent as fractions of the ground-truth image along each axis.
        x0, y0 = ii / height, jj / width
        x1, y1 = (ii + p) / height, (jj + p) / width
        coord, scale = make_coord_scale_grid((p, p), range=[[x0, x1], [y0, y1]])

        ret['gt'] = torch.cat([
            gt_patch,  # C p p (3 for RGB input)
            coord.permute(2, 0, 1),  # channels first; 2 p p per the original note
            scale.permute(2, 0, 1),  # channels first; 2 p p per the original note
        ], dim=0)
        return ret
@register('wrapper_cae')
class WrapperCAE(BaseWrapperCAE, Dataset):
    """Map-style dataset that applies ``BaseWrapperCAE.process`` per item."""

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the processed sample at ``idx``.

        If the wrapped dataset yields a dict, its ``'image'`` entry is
        processed and every other entry is carried over unchanged (overriding
        a processed key of the same name). Otherwise the item itself is
        treated as the image.

        Raises:
            KeyError: If a dict item has no ``'image'`` entry.
        """
        data = self.dataset[idx]
        if not isinstance(data, dict):
            return self.process(data)

        # Read the image without popping it: the wrapped dataset may hand out
        # a cached/shared dict, and mutating it would break later accesses.
        ret = dict(self.process(data['image']))
        ret.update((key, value) for key, value in data.items() if key != 'image')
        return ret
@register('wrapper_cae_iterable')
class WrapperCAEIterable(BaseWrapperCAE, IterableDataset):
    """Iterable-style dataset that applies ``BaseWrapperCAE.process`` lazily.

    Deliberately not named ``WrapperCAE``: reusing the name of the map-style
    wrapper defined earlier in this module would rebind the module attribute,
    making that class unreachable by name and unpicklable by reference. The
    registry key ``'wrapper_cae_iterable'`` is unchanged.
    """

    def __iter__(self):
        """Yield one processed sample per item of the wrapped dataset.

        If an item is a dict, its ``'image'`` entry is processed and every
        other entry is carried over unchanged (overriding a processed key of
        the same name). Otherwise the item itself is treated as the image.

        Raises:
            KeyError: If a dict item has no ``'image'`` entry.
        """
        for data in self.dataset:
            if not isinstance(data, dict):
                yield self.process(data)
                continue

            # Read the image without popping it so the wrapped dataset's own
            # dict is left intact for any later iteration over it.
            ret = dict(self.process(data['image']))
            ret.update(
                (key, value) for key, value in data.items() if key != 'image'
            )
            yield ret