# File size: 2,447 Bytes
# Revision: e8e478e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import numpy as np
import torch.utils
import torch.utils.data
from torch.utils.data.sampler import WeightedRandomSampler
import torch.distributed as dist
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

from .datasets import RealFakeDataset

    

def get_bal_sampler(dataset):
    """Return a sampler that draws every class with roughly equal probability.

    Each sample is weighted by the inverse of its class count, so every class
    present contributes the same total weight.

    Args:
        dataset: Object whose ``datasets`` attribute is an iterable of
            sub-datasets, each exposing a ``targets`` sequence of non-negative
            integer labels (``np.bincount`` rejects negative values).

    Returns:
        A ``WeightedRandomSampler`` producing ``len(labels)`` indices, drawn
        with replacement (the sampler's default).
    """
    # Flatten labels in sub-dataset order. This order is assumed to match the
    # outer dataset's index order (e.g. a ConcatDataset) -- confirm at callers.
    labels = [label for sub in dataset.datasets for label in sub.targets]

    class_counts = np.bincount(labels)
    inverse_freq = 1. / torch.tensor(class_counts, dtype=torch.float)
    # Gather the per-class weight for every individual sample.
    per_sample = inverse_freq[labels]
    return WeightedRandomSampler(weights=per_sample,
                                 num_samples=len(per_sample))


def create_train_val_dataloader(opt, clip_model, transform, k_split: float):
    """Build train and validation DataLoaders from one ``RealFakeDataset``.

    The dataset is split at random into a training part of
    ``int(len(dataset) * k_split)`` samples and a validation part holding the
    remainder.

    Args:
        opt: Options object. ``serial_batches``, ``isTrain``, ``class_bal`` and
            ``batch_size`` are read here; it is also passed to
            ``RealFakeDataset``.
        clip_model: Forwarded to ``RealFakeDataset``.
        transform: Forwarded to ``RealFakeDataset``.
        k_split: Fraction of samples assigned to the training split; must lie
            in ``[0, 1]``.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles as
        dictated by ``opt``; the validation loader keeps a fixed order.

    Raises:
        ValueError: If ``k_split`` is outside ``[0, 1]``. Such values would
            yield a negative validation length and a meaningless split.
    """
    # Validate before building the (potentially expensive) dataset.
    if not 0.0 <= k_split <= 1.0:
        raise ValueError(f"k_split must be in [0, 1], got {k_split!r}")

    # Shuffle only in training mode without class balancing, and only when
    # serial batches were not requested.
    # NOTE(review): opt.class_bal is not honoured on this path. get_bal_sampler
    # reads `dataset.datasets`, which the Subset returned by random_split does
    # not expose. Confirm whether a balanced sampler is wanted here.
    shuffle = not opt.serial_batches if (opt.isTrain and not opt.class_bal) else False

    dataset = RealFakeDataset(opt, clip_model, transform)

    # Split into training and validation sets.
    dataset_size = len(dataset)
    train_size = int(dataset_size * k_split)
    val_size = dataset_size - train_size

    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    # Honour the computed flag so the training split is reshuffled every epoch
    # when enabled (it was previously computed but ignored).
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=opt.batch_size,
                                               shuffle=shuffle,
                                               num_workers=16
                                               )
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=opt.batch_size,
                                             shuffle=False,
                                             num_workers=16
                                             )

    return train_loader, val_loader


def create_test_dataloader(opt, clip_model, transform):
    """Build a single DataLoader over a ``RealFakeDataset``.

    Args:
        opt: Options object. ``isTrain``, ``class_bal``, ``serial_batches`` and
            ``batch_size`` are read here; it is also passed to
            ``RealFakeDataset``.
        clip_model: Forwarded to ``RealFakeDataset``.
        transform: Forwarded to ``RealFakeDataset``.

    Returns:
        A ``torch.utils.data.DataLoader`` with 16 workers. When
        ``opt.class_bal`` is set, a class-balanced sampler from
        ``get_bal_sampler`` is used; otherwise the loader shuffles only in
        training mode without serial batches.
    """
    # Shuffling is enabled only for training runs without class balancing.
    if opt.isTrain and not opt.class_bal:
        use_shuffle = not opt.serial_batches
    else:
        use_shuffle = False

    test_set = RealFakeDataset(opt, clip_model, transform)

    # DataLoader does not allow shuffle together with a custom sampler;
    # use_shuffle is always False whenever this sampler is built.
    balanced_sampler = None
    if opt.class_bal:
        balanced_sampler = get_bal_sampler(test_set)

    loader_kwargs = dict(batch_size=opt.batch_size,
                         shuffle=use_shuffle,
                         sampler=balanced_sampler,
                         num_workers=16)
    return torch.utils.data.DataLoader(test_set, **loader_kwargs)