|
import torch |
|
from datasets import load_dataset |
|
from torch.utils.data import DataLoader |
|
from transformers import BertTokenizer |
|
import decord |
|
import numpy as np |
|
from tqdm import tqdm |
|
|
|
|
|
# Number of frames uniformly sampled from each video clip.
FRAMES = 400

# Target spatial resolution (height, width) each frame is resized to.
H, W = 780, 780

# Videos per batch produced by the DataLoader.
BATCH_SIZE = 8

# Captions are padded/truncated to this many tokens.
TEXT_MAX_LEN = 32
|
|
|
|
|
# Download/load the video-caption dataset from the HuggingFace Hub.
# NOTE(review): assumes each record carries "video_path" and "caption"
# fields (read below in VideoDataset.__getitem__) — verify against the
# dataset card.
dataset = load_dataset("minh132/pexels-videos", split="train")

# Tokenizer shared by all dataset workers for encoding captions.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
|
|
|
|
class VideoDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing decoded video frames with tokenized captions.

    Each item is a dict:
        video: float tensor of shape (C, FRAMES, H, W), values in [-1, 1]
        text:  LongTensor of token ids, length TEXT_MAX_LEN

    Parameters
    ----------
    dataset : a sequence of records, each with a "video_path" (local file
        path readable by decord) and a "caption" string.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # Decode on CPU; a single context is reused for every read.
        self.decord_ctx = decord.cpu(0)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        vr = decord.VideoReader(item["video_path"], ctx=self.decord_ctx)
        # Uniformly sample FRAMES indices across the whole clip.
        # NOTE(review): clips shorter than FRAMES yield repeated indices,
        # i.e. duplicated frames rather than an error — confirm intended.
        frame_indices = np.linspace(0, len(vr) - 1, FRAMES, dtype=int)

        # BUGFIX: decord's NDArray exposes .asnumpy(), not .numpy()
        # (.numpy() only exists after decord.bridge.set_bridge("torch"),
        # which this file never calls).
        video = vr.get_batch(frame_indices).asnumpy()

        # (T, H, W, C) -> (C, T, H, W)
        video = torch.from_numpy(video).permute(3, 0, 1, 2).float()

        # Bilinear interpolate on a 4-D tensor resizes the last two dims,
        # so each frame becomes (H, W); the channel dim acts as the batch.
        # align_corners=False made explicit (it is the default).
        video = F.interpolate(
            video, size=(H, W), mode="bilinear", align_corners=False
        )
        # Rescale pixels from [0, 255] to [-1, 1].
        video = (video / 255.0) * 2 - 1

        # Encode the caption to a fixed-length id tensor; squeeze the
        # batch dim added by return_tensors="pt".
        text = tokenizer(
            item["caption"],
            padding="max_length",
            truncation=True,
            max_length=TEXT_MAX_LEN,
            return_tensors="pt",
        ).input_ids.squeeze(0)

        return {"video": video, "text": text}
|
|
|
|
|
# Wrap the raw HF dataset so __getitem__ decodes frames and tokenizes text.
dataset = VideoDataset(dataset)

# num_workers=4: frame decoding is CPU-heavy, so parallel workers keep the
# GPU fed; shuffle for training.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)