File size: 2,340 Bytes
00db68b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import json
import os
import random
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data.dataset import Dataset
class CC15M(Dataset):
    """Image-caption dataset backed by a JSON annotation file.

    The annotation file must parse to an integer-indexable sequence (a JSON
    list) of records, each carrying a ``'file_path'`` key (image path,
    optionally relative to ``video_folder``) and a ``'text'`` key (caption).

    Args:
        json_path: Path to the JSON annotation file.
        video_folder: Optional root directory joined in front of each record's
            ``'file_path'``. When ``None`` the stored path is used as-is.
        resolution: Output size; an int means a square ``(r, r)``, otherwise a
            2-sequence passed to ``CenterCrop`` as ``(height, width)``.
        enable_bucket: When True, ``__getitem__`` skips the tensor transforms
            and returns the RGB image as a numpy array (presumably so a
            bucketing sampler/collate can resize later — TODO confirm).
    """

    def __init__(
        self,
        json_path,
        video_folder=None,
        resolution=512,
        enable_bucket=False,
    ):
        print(f"loading annotations from {json_path} ...")
        self.dataset = self._load_annotations(json_path)
        self.length = len(self.dataset)
        print(f"data scale: {self.length}")

        self.enable_bucket = enable_bucket
        self.video_folder = video_folder

        # Normalize an int resolution into a square (height, width) pair.
        resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
        self.pixel_transforms = transforms.Compose([
            # Resize the shorter edge to the target height, then center-crop.
            # NOTE(review): for non-square targets with width > height, the
            # resized image can be narrower than the crop, in which case
            # CenterCrop zero-pads; confirm this is intended.
            transforms.Resize(resolution[0]),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            # Maps [0, 1] to [-1, 1]. inplace is safe: ToTensor yields a fresh tensor.
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    @staticmethod
    def _load_annotations(json_path):
        """Parse and return the JSON annotation file, closing it afterwards."""
        # JSON text is UTF-8 by specification; don't rely on the locale encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_batch(self, idx):
        """Load the image and caption for record ``idx``.

        Returns:
            A ``(PIL image in RGB mode, caption)`` tuple.

        Raises:
            Whatever indexing, the record lookup, or ``Image.open`` raise
            (e.g. missing keys, missing or unreadable files).
        """
        video_dict = self.dataset[idx]
        video_id, name = video_dict['file_path'], video_dict['text']
        if self.video_folder is None:
            video_dir = video_id
        else:
            video_dir = os.path.join(self.video_folder, video_id)
        # convert() returns a new, fully loaded image, so the file handle that
        # Image.open holds can be released as soon as the with-block exits.
        with Image.open(video_dir) as image:
            pixel_values = image.convert("RGB")
        return pixel_values, name

    def __len__(self):
        """Return the number of annotation records."""
        return self.length

    def __getitem__(self, idx):
        """Return ``{'pixel_values': ..., 'text': caption}`` for ``idx``.

        If loading raises, the error is printed and a uniformly random index is
        tried instead, repeating until a sample loads.
        NOTE(review): this retries forever if every record is unreadable
        (e.g. a wrong ``video_folder``).

        ``pixel_values`` is the normalized tensor from ``pixel_transforms``
        when bucketing is disabled. Otherwise it is ``np.array`` of the RGB
        image.
        """
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception as e:
                print(e)
                idx = random.randint(0, self.length - 1)

        if not self.enable_bucket:
            pixel_values = self.pixel_transforms(pixel_values)
        else:
            pixel_values = np.array(pixel_values)

        sample = dict(pixel_values=pixel_values, text=name)
        return sample
if __name__ == "__main__":
    # Smoke test: build the dataset and print each batch's tensor shape and
    # caption count.
    dataset = CC15M(
        # The constructor parameter is json_path; the old csv_path keyword
        # raised TypeError before anything ran.
        json_path="/mnt_wg/zhoumo.xjq/CCUtils/cc15m_add_index.json",
        resolution=512,
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0)
    for batch in dataloader:
        print(batch["pixel_values"].shape, len(batch["text"]))