# HyperReel-V1-B / dataloader.py
# Source: SkillForge45 — "Create dataloader.py" (commit 9745119, verified)
import decord
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertTokenizer
FRAMES = 400
H, W = 780, 780
BATCH_SIZE = 8
TEXT_MAX_LEN = 32
dataset = load_dataset("minh132/pexels-videos", split="train")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
class VideoDataset(torch.utils.data.Dataset):
    """Decode videos with decord and pair them with tokenized captions.

    Each item is a dict:
        "video": float tensor of shape (3, FRAMES, H, W), values in [-1, 1]
        "text":  LongTensor of shape (TEXT_MAX_LEN,) of BERT input ids
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.decord_ctx = decord.cpu(0)  # CPU decoding (safe in DataLoader workers)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        vr = decord.VideoReader(item["video_path"], ctx=self.decord_ctx)
        # Uniformly sample FRAMES indices across the clip; linspace repeats
        # indices when the clip is shorter than FRAMES, so short videos work too.
        frame_indices = np.linspace(0, len(vr) - 1, FRAMES, dtype=int)
        # BUGFIX: decord's native NDArray exposes .asnumpy(), not .numpy()
        # (.numpy() only exists after decord.bridge.set_bridge("torch"),
        # which is never set here — the original raised AttributeError).
        video = vr.get_batch(frame_indices).asnumpy()  # (FRAMES, H0, W0, 3)
        video = torch.from_numpy(video).permute(3, 0, 1, 2).float()  # (3, FRAMES, H0, W0)
        # interpolate treats dim 0 as batch and dim 1 as channels, so this
        # resizes each frame's spatial dims to (H, W). align_corners made
        # explicit (same default behavior, silences the warning).
        video = F.interpolate(video, size=(H, W), mode="bilinear", align_corners=False)
        video = (video / 255.0) * 2 - 1  # scale uint8 range to [-1, 1]
        text = tokenizer(
            item["caption"],
            padding="max_length",
            truncation=True,
            max_length=TEXT_MAX_LEN,
            return_tensors="pt",
        ).input_ids.squeeze(0)
        return {"video": video, "text": text}
# Wrap the raw HF split in the decoding dataset, then build the training loader.
dataset = VideoDataset(dataset)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)