Jassk28 committed on
Commit
767e8bb
1 Parent(s): aac53ff

Upload real_n_fake_dataloader.py

Files changed (1)
  1. dataset/real_n_fake_dataloader.py +119 -0
dataset/real_n_fake_dataloader.py ADDED
@@ -0,0 +1,119 @@
+ # We will use this file to create a dataloader for the real and fake dataset
+ import os
+ import json
+ import torch
+ from torchvision import transforms
+ from torch.utils.data import DataLoader, Dataset
+ from PIL import Image
+ import numpy as np
+ import pandas as pd
+ import cv2
+ import pywt
+
+
+ class Extracted_Frames_Dataset(Dataset):
+     def __init__(self, root_dir, split="train", transform=None, extend='None', multi_modal="dct"):
+         """
+         Args:
+             root_dir: directory containing the per-split CSV files.
+             split: one of "train", "val", "test".
+             transform: optional transform applied to every returned image modality.
+             extend: which CSV variant to load ('faceswap', 'fsgan', or 'None' for the base split).
+             multi_modal: second modality returned alongside the RGB frame ("dct", "fft", or "hh").
+         Returns:
+             Each sample is a dict with keys "rgb_image", "dct_image" (the chosen second
+             modality) and "label".
+         """
+         assert split in ["train", "val", "test"], "Split must be one of (train, val, test)"
+         self.multi_modal = multi_modal
+         self.root_dir = root_dir
+         self.split = split
+         self.transform = transform
+         if extend == 'faceswap':
+             self.dataset = pd.read_csv(os.path.join(root_dir, f"faceswap_extended_{self.split}.csv"))
+         elif extend == 'fsgan':
+             self.dataset = pd.read_csv(os.path.join(root_dir, f"fsgan_extended_{self.split}.csv"))
+         else:
+             self.dataset = pd.read_csv(os.path.join(root_dir, f"{self.split}.csv"))
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         sample_input = self.get_sample_input(idx)
+         return sample_input
+
+     def get_sample_input(self, idx):
+         rgb_image = self.get_rgb_image(idx)
+         label = self.get_label(idx)
+         if self.multi_modal == "dct":
+             dct_image = self.get_dct_image(idx)
+             sample_input = {"rgb_image": rgb_image, "dct_image": dct_image, "label": label}
+         elif self.multi_modal == "fft":
+             fft_image = self.get_fft_image(idx)
+             sample_input = {"rgb_image": rgb_image, "dct_image": fft_image, "label": label}
+         elif self.multi_modal == "hh":
+             hh_image = self.get_hh_image(idx)
+             sample_input = {"rgb_image": rgb_image, "dct_image": hh_image, "label": label}
+         else:
+             raise ValueError("multi_modal must be one of (dct: discrete cosine transform, fft: fast Fourier transform, hh: Haar wavelet HH sub-band)")
+         return sample_input
+
+     def get_fft_image(self, idx):
+         gray_image_path = self.dataset.iloc[idx, 0]
+         gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
+         fft_image = self.compute_fft(gray_image)
+         if self.transform:
+             fft_image = self.transform(fft_image)
+         return fft_image
+
+     def compute_fft(self, image):
+         # 2-D FFT with the zero-frequency component shifted to the centre; log-scaled magnitude.
+         f = np.fft.fft2(image)
+         fshift = np.fft.fftshift(f)
+         magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)  # add 1 to avoid log(0)
+         return magnitude_spectrum
+
+     def get_hh_image(self, idx):
+         gray_image_path = self.dataset.iloc[idx, 0]
+         gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
+         hh_image = self.compute_hh(gray_image)
+         if self.transform:
+             hh_image = self.transform(hh_image)
+         return hh_image
+
+     def compute_hh(self, image):
+         # Single-level 2-D Haar wavelet decomposition; keep only the diagonal-detail (HH) sub-band.
+         coeffs2 = pywt.dwt2(image, 'haar')
+         LL, (LH, HL, HH) = coeffs2
+         return HH
+
+     def get_rgb_image(self, idx):
+         rgb_image_path = self.dataset.iloc[idx, 0]
+         rgb_image = Image.open(rgb_image_path)
+         if self.transform:
+             rgb_image = self.transform(rgb_image)
+         return rgb_image
+
+     def get_dct_image(self, idx):
+         rgb_image_path = self.dataset.iloc[idx, 0]
+         rgb_image = cv2.imread(rgb_image_path)
+         dct_image = self.compute_dct_color(rgb_image)
+         if self.transform:
+             dct_image = self.transform(dct_image)
+         return dct_image
+
+     def get_label(self, idx):
+         return self.dataset.iloc[idx, 1]
+
+     def compute_dct_color(self, image):
+         # Channel-wise 2-D discrete cosine transform on the float32 image.
+         image_float = np.float32(image)
+         dct_image = np.zeros_like(image_float)
+         for i in range(3):
+             dct_image[:, :, i] = cv2.dct(image_float[:, :, i])
+         return dct_image
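For reference, a minimal usage sketch follows. It is not part of the commit: the root_dir path, transform, and DataLoader settings are placeholders, and it assumes the split CSVs store the image path in the first column and the label in the second, which is how the class above indexes them.

# Minimal usage sketch (not part of the commit): wrap the dataset in a DataLoader.
# Assumes each split CSV lives directly under root_dir, with the image path in
# column 0 and the label in column 1; paths and hyperparameters are placeholders.
from torchvision import transforms
from torch.utils.data import DataLoader

shared_transform = transforms.Compose([
    transforms.ToTensor(),           # PIL image or HxW(xC) ndarray -> CHW float tensor
    transforms.Resize((224, 224)),   # fixed spatial size so samples collate into a batch
])

train_set = Extracted_Frames_Dataset(
    root_dir="dataset/csvs",         # hypothetical location of the train/val/test CSVs
    split="train",
    transform=shared_transform,
    multi_modal="dct",
)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)

for batch in train_loader:
    rgb = batch["rgb_image"]         # transformed RGB frames
    aux = batch["dct_image"]         # second modality (channel-wise DCT here)
    labels = batch["label"]
    break

The same transform is applied to both modalities, so it should accept either a PIL image (RGB branch) or a NumPy array (DCT/FFT/HH branches); ToTensor followed by a tensor Resize handles both cases.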