import polars as pl
import torch
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from torchvision.transforms import Compose


class FrameDataset(Dataset):
    def __init__(
        self,
        df: pl.DataFrame,
        augments: Compose,
        frames_per_clip: int,
        stride: int | None = None,
        is_train: bool = True,
    ):
        super().__init__()
        self.paths = df["path"].to_list()
        self.is_train = is_train
        if is_train:
            # Go through numpy so torch can infer the dtype reliably;
            # passing a polars Series to torch.tensor directly is fragile.
            self.y = torch.tensor(df["label"].to_numpy())
        self.frames_per_clip = frames_per_clip
        self.augments = augments
        # By default clips do not overlap: advance one full clip per index.
        self.stride = stride or frames_per_clip

    def __len__(self):
        # Count only clips that fit entirely within the frame list, so the
        # last index never yields a short, truncated clip that would break
        # batch collation.
        return max(0, (len(self.paths) - self.frames_per_clip) // self.stride + 1)

    def __getitem__(self, idx):
        start = idx * self.stride
        stop = start + self.frames_per_clip
        if self.frames_per_clip <= 1:
            # Single-frame mode: return one (C, H, W) image with its own label.
            frames_tr = self._open_augment_img(self.paths[start])
            if self.is_train:
                y = self.y[start]
        else:
            # Clip mode: stack consecutive frames into a (T, C, H, W) tensor;
            # the clip label is the max over its frames' labels.
            frames = [self._open_augment_img(path) for path in self.paths[start:stop]]
            frames_tr = torch.stack(frames)
            if self.is_train:
                y = self.y[start:stop].max()
        if self.is_train:
            return frames_tr, y
        return frames_tr

    def _open_augment_img(self, path):
        # Load the image from disk (PIL, via torchvision's default loader)
        # and apply the augmentation pipeline.
        img = default_loader(path)
        return self.augments(img)