import random

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, IterableDataset
from torchvision import transforms

import datasets
from datasets import register
from models.ldm.dac.audiotools import AudioSignal
from models.ldm.dac.audiotools import transforms as tfm
from models.ldm.dac.audiotools.data.datasets import AudioDataset, AudioLoader
from utils.geometry import make_coord_scale_grid
class BaseWrapperCAE:
    """Wraps a dataset of square images, producing a fixed-size encoder input
    and, optionally, a randomly-sized ground-truth patch together with its
    coordinate and scale grids."""

    def __init__(
        self,
        dataset,
        resize_inp,
        return_gt=True,
        gt_glores_lb=None,
        gt_glores_ub=None,
        gt_patch_size=None,
        p_whole=0.0,
        p_max=0.0,
    ):
        self.dataset = datasets.make(dataset)
        self.resize_inp = resize_inp
        self.return_gt = return_gt
        self.gt_glores_lb = gt_glores_lb
        self.gt_glores_ub = gt_glores_ub
        self.gt_patch_size = gt_patch_size
        self.p_whole = p_whole
        self.p_max = p_max
        # Map PIL image -> float tensor normalized to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5),
        ])
    def process(self, image):
        assert image.size[0] == image.size[1], 'expects a square image'
        ret = {}

        # Fixed-resolution encoder input.
        inp = image.resize((self.resize_inp, self.resize_inp), Image.LANCZOS)
        ret['inp'] = self.transform(inp)
        if not self.return_gt:
            return ret

        # Pick the global resolution r of the ground-truth view.
        if self.gt_glores_lb is None:
            glo = self.transform(image)
        else:
            if random.random() < self.p_whole:
                # Patch covers the whole image: resize the image to patch size.
                r = self.gt_patch_size
            elif random.random() < self.p_max:
                # Use the maximum allowed resolution.
                r = min(image.size[0], self.gt_glores_ub)
            else:
                # Sample a resolution uniformly between the bounds.
                r = random.randint(
                    self.gt_glores_lb,
                    max(self.gt_glores_lb, min(image.size[0], self.gt_glores_ub))
                )
            glo = image.resize((r, r), Image.LANCZOS)
            glo = self.transform(glo)

        # Crop a random p x p patch and record its normalized location.
        p = self.gt_patch_size
        ii = random.randint(0, glo.shape[1] - p)
        jj = random.randint(0, glo.shape[2] - p)
        gt_patch = glo[:, ii: ii + p, jj: jj + p]
        x0, y0 = ii / glo.shape[-2], jj / glo.shape[-1]
        x1, y1 = (ii + p) / glo.shape[-2], (jj + p) / glo.shape[-1]
        coord, scale = make_coord_scale_grid((p, p), range=[[x0, x1], [y0, y1]])
        ret['gt'] = torch.cat([
            gt_patch,                # 3 x p x p RGB patch
            coord.permute(2, 0, 1),  # 2 x p x p pixel-center coordinates
            scale.permute(2, 0, 1),  # 2 x p x p pixel-size scales
        ], dim=0)
        return ret
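# A small helper (not part of the original module; illustration only) that
# undoes the channel packing performed in BaseWrapperCAE.process: the 'gt'
# tensor stacks the RGB patch, the coordinate grid, and the scale grid
# along the channel dimension.
def split_gt(gt):
    # gt: (7, p, p) -> patch (3, p, p), coord (2, p, p), scale (2, p, p)
    patch, coord, scale = torch.split(gt, [3, 2, 2], dim=0)
    return patch, coord, scale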
class WrapperCAE(BaseWrapperCAE, Dataset):

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        if isinstance(data, dict):
            # Process the image; pass any remaining metadata keys through.
            ret = dict()
            ret.update(self.process(data.pop('image')))
            ret.update(data)
            return ret
        else:
            return self.process(data)
class WrapperCAEIterable(BaseWrapperCAE, IterableDataset):

    def __iter__(self):
        for data in self.dataset:
            if isinstance(data, dict):
                # Process the image; pass any remaining metadata keys through.
                ret = dict()
                ret.update(self.process(data.pop('image')))
                ret.update(data)
                yield ret
            else:
                yield self.process(data)
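

if __name__ == '__main__':
    # Minimal usage sketch, not from the original file. The dataset spec
    # below is hypothetical: datasets.make is assumed to build a registered
    # dataset from a {'name': ..., 'args': ...} dict that yields square PIL
    # images (or dicts containing an 'image' key).
    wrapper = WrapperCAE(
        dataset={'name': 'image_folder', 'args': {'root_path': 'data/train'}},
        resize_inp=256,
        gt_glores_lb=256,
        gt_glores_ub=1024,
        gt_patch_size=256,
        p_whole=0.1,
        p_max=0.1,
    )
    sample = wrapper[0]
    print(sample['inp'].shape)  # expected: (3, 256, 256)
    print(sample['gt'].shape)   # expected: (7, 256, 256)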