import os
from os import path

from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torchvision.transforms import InterpolationMode
import torch.nn.functional as Ff
from PIL import Image
import numpy as np

from dataset.range_transform import im_normalization, im_rgb2lab_normalization, ToTensor, RGB2Lab

class VideoReader_221128_TransColorization(Dataset):
    """
    This class is used to read a video, one frame at a time.
    """
    def __init__(self, vid_name, image_dir, mask_dir, size=-1, to_save=None, use_all_mask=False, size_dir=None, args=None):
        """
        image_dir - points to a directory of jpg images
        mask_dir - points to a directory of png masks
        size - resize the shorter side to this size; does nothing if < 0
        to_save - optionally contains a list of file names (without extensions)
                  for which the segmentation mask is required
        use_all_mask - when True, read every available mask in mask_dir.
                       Default False. Set to True for YouTubeVOS validation.
        """
        self.vid_name = vid_name
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.to_save = to_save
        self.use_all_mask = use_all_mask
        # print('use_all_mask', use_all_mask);assert 1==0
        if size_dir is None:
            self.size_dir = self.image_dir
        else:
            self.size_dir = size_dir

        # flag_reverse = args.getattr('reverse', False) if args is not None else False
        flag_reverse = False
        self.frames = [img for img in sorted(os.listdir(self.image_dir), reverse=flag_reverse)
                       if (img.endswith('.jpg') or img.endswith('.png')) and not img.startswith('.')]
        self.palette = Image.open(path.join(mask_dir, sorted([msk for msk in os.listdir(mask_dir) if not msk.startswith('.')])[0])).getpalette()
        self.first_gt_path = path.join(self.mask_dir, sorted([msk for msk in os.listdir(self.mask_dir) if not msk.startswith('.')])[0])
        self.suffix = self.first_gt_path.split('.')[-1]

        if size < 0:
            # no resizing: convert RGB to Lab and normalize in Lab space
            self.im_transform = transforms.Compose([
                RGB2Lab(),
                ToTensor(),
                im_rgb2lab_normalization,
            ])
        else:
            # resize the shorter side to `size`; this branch normalizes in RGB rather than Lab
            self.im_transform = transforms.Compose([
                transforms.ToTensor(),
                im_normalization,
                transforms.Resize(size, interpolation=InterpolationMode.BILINEAR),
            ])
        self.size = size

    def __getitem__(self, idx):
        frame = self.frames[idx]
        info = {}
        data = {}
        info['frame'] = frame
        info['vid_name'] = self.vid_name
        info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)

        im_path = path.join(self.image_dir, frame)
        img = Image.open(im_path).convert('RGB')

        if self.image_dir == self.size_dir:
            shape = np.array(img).shape[:2]
        else:
            size_path = path.join(self.size_dir, frame)
            size_im = Image.open(size_path).convert('RGB')
            shape = np.array(size_im).shape[:2]

        # mask for this index, if one exists (assumes masks sort in the same order as frames)
        gt_path = path.join(self.mask_dir, sorted(os.listdir(self.mask_dir))[idx]) if idx < len(os.listdir(self.mask_dir)) else None

        img = self.im_transform(img)
        img_l = img[:1, :, :]            # keep only the L (luminance) channel
        img_lll = img_l.repeat(3, 1, 1)  # repeat L across 3 channels for an RGB-shaped input
        load_mask = self.use_all_mask or (gt_path == self.first_gt_path)
        if load_mask and gt_path is not None and path.exists(gt_path):
            mask = Image.open(gt_path).convert('RGB')
            # resize with PIL first so the mask matches the transformed image size
            mask = mask.resize((img.shape[2], img.shape[1]), Image.BILINEAR)
            mask = self.im_transform(mask)
            # keep L channel of reference image in case the first frame is not the exemplar
            # mask_ab = mask[1:3,:,:]
            # data['mask'] = mask_ab
            data['mask'] = mask

        info['shape'] = shape
        info['need_resize'] = not (self.size < 0)
        data['rgb'] = img_lll
        data['info'] = info

        return data

    def resize_mask(self, mask):
        # mask transform is applied AFTER mapper, so we need to post-process it in eval.py
        h, w = mask.shape[-2:]
        min_hw = min(h, w)
        return Ff.interpolate(mask, (int(h/min_hw*self.size), int(w/min_hw*self.size)),
                              mode='nearest')

    def get_palette(self):
        return self.palette

    def __len__(self):
        return len(self.frames)
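

# Minimal usage sketch (not part of the original file). The directory paths and video
# name below are placeholders, and the DataLoader settings are assumptions: the reader
# yields one frame per item, so batch_size=1 without shuffling mirrors sequential
# video reading.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    reader = VideoReader_221128_TransColorization(
        vid_name='example_video',            # placeholder video name
        image_dir='path/to/frames',          # placeholder: directory of .jpg/.png frames
        mask_dir='path/to/exemplar_masks',   # placeholder: directory of reference masks
    )
    loader = DataLoader(reader, batch_size=1, shuffle=False)
    for data in loader:
        img_lll = data['rgb']   # (1, 3, H, W): L channel repeated across 3 channels
        info = data['info']     # frame name, video name, original shape, save flag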