Jassk28 committed
Commit 50253c6 · 1 Parent(s): 00030ca

Upload test_image_fusion.py

Files changed (1)
test_image_fusion.py +182 -0
test_image_fusion.py ADDED
@@ -0,0 +1,182 @@
import os
import warnings

import cv2
import numpy as np
import pywt
import torch
from PIL import Image
from torch.utils.data import Dataset

from utils.data_transforms import get_transforms_val
from net.Multimodalmodel import Image_n_DCT

warnings.filterwarnings("ignore")

class Test_Dataset(Dataset):
    def __init__(self, test_data_path=None, transform=None, image_path=None, multi_modal="dct"):
        """
        Args:
            test_data_path: directory with "real" and "fake" subfolders of images.
            transform: optional transform applied to each returned image.
            image_path: single image to classify when test_data_path is None.
            multi_modal: frequency channel paired with the RGB image; one of
                "dct", "fft", or "hh".
        Returns:
            samples as dicts with keys "rgb_image", "dct_image", and "label".
        """
        self.multi_modal = multi_modal
        self.transform = transform
        if test_data_path is None and image_path is not None:
            # Single-image mode; label 2 is a placeholder (ground truth unknown).
            self.dataset = [[image_path, 2]]
        else:
            self.real_data = os.listdir(test_data_path + "/real")
            self.fake_data = os.listdir(test_data_path + "/fake")
            self.dataset = []
            for image in self.real_data:
                self.dataset.append([test_data_path + "/real/" + image, 1])
            for image in self.fake_data:
                self.dataset.append([test_data_path + "/fake/" + image, 0])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.get_sample_input(idx)

    def get_sample_input(self, idx):
        # Pair the RGB image with the requested frequency-domain channel.
        # The key is "dct_image" for every modality so the model input stays uniform.
        rgb_image = self.get_rgb_image(idx)
        label = self.get_label(idx)
        if self.multi_modal == "dct":
            dct_image = self.get_dct_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": dct_image, "label": label}
        elif self.multi_modal == "fft":
            fft_image = self.get_fft_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": fft_image, "label": label}
        elif self.multi_modal == "hh":
            hh_image = self.get_hh_image(idx)
            sample_input = {"rgb_image": rgb_image, "dct_image": hh_image, "label": label}
        else:
            raise ValueError("multi_modal must be one of (dct: discrete cosine transform, "
                             "fft: fast Fourier transform, hh: Haar wavelet HH sub-band)")

        return sample_input

    def get_fft_image(self, idx):
        gray_image_path = self.dataset[idx][0]
        gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
        fft_image = self.compute_fft(gray_image)
        if self.transform:
            fft_image = self.transform(fft_image)
        return fft_image

    def compute_fft(self, image):
        # Log-scaled magnitude spectrum of the centered 2-D FFT.
        f = np.fft.fft2(image)
        fshift = np.fft.fftshift(f)
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)  # add 1 to avoid log(0)
        return magnitude_spectrum

    def get_hh_image(self, idx):
        gray_image_path = self.dataset[idx][0]
        gray_image = cv2.imread(gray_image_path, cv2.IMREAD_GRAYSCALE)
        hh_image = self.compute_hh(gray_image)
        if self.transform:
            hh_image = self.transform(hh_image)
        return hh_image

    def compute_hh(self, image):
        # Single-level 2-D Haar wavelet transform; keep the HH
        # (diagonal detail) sub-band.
        coeffs2 = pywt.dwt2(image, 'haar')
        LL, (LH, HL, HH) = coeffs2
        return HH

    def get_rgb_image(self, idx):
        rgb_image_path = self.dataset[idx][0]
        rgb_image = Image.open(rgb_image_path).convert("RGB")  # force 3 channels
        if self.transform:
            rgb_image = self.transform(rgb_image)
        return rgb_image

    def get_dct_image(self, idx):
        rgb_image_path = self.dataset[idx][0]
        rgb_image = cv2.imread(rgb_image_path)
        dct_image = self.compute_dct_color(rgb_image)
        if self.transform:
            dct_image = self.transform(dct_image)
        return dct_image

    def get_label(self, idx):
        return self.dataset[idx][1]

    def compute_dct_color(self, image):
        # Per-channel 2-D discrete cosine transform of the BGR image.
        image_float = np.float32(image)
        dct_image = np.zeros_like(image_float)
        for i in range(3):
            dct_image[:, :, i] = cv2.dct(image_float[:, :, i])
        return dct_image

class Test:
    def __init__(self, model_paths=['weights/faceswap-hh-best_model.pth',
                                    'weights/faceswap-fft-best_model.pth'],
                 multi_modal=["hh", "fft"]):
        self.model_path = model_paths
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device)

        # Load one checkpoint per modality and put both models in eval mode.
        self.model1 = Image_n_DCT()
        self.model1.load_state_dict(torch.load(self.model_path[0], map_location=self.device))
        self.model1.to(self.device)
        self.model1.eval()

        self.model2 = Image_n_DCT()
        self.model2.load_state_dict(torch.load(self.model_path[1], map_location=self.device))
        self.model2.to(self.device)
        self.model2.eval()

        self.multi_modal = multi_modal

    def testimage(self, image_path):
        # Build one single-image dataset per modality.
        test_dataset1 = Test_Dataset(transform=get_transforms_val(), image_path=image_path, multi_modal=self.multi_modal[0])
        test_dataset2 = Test_Dataset(transform=get_transforms_val(), image_path=image_path, multi_modal=self.multi_modal[1])

        inputs1 = test_dataset1[0]
        rgb_image1, dct_image1 = inputs1['rgb_image'].to(self.device), inputs1['dct_image'].to(self.device)

        inputs2 = test_dataset2[0]
        rgb_image2, dct_image2 = inputs2['rgb_image'].to(self.device), inputs2['dct_image'].to(self.device)

        with torch.no_grad():
            output1 = self.model1(rgb_image1.unsqueeze(0), dct_image1.unsqueeze(0))
            output2 = self.model2(rgb_image2.unsqueeze(0), dct_image2.unsqueeze(0))

        # Late fusion: average the two streams' outputs and take the argmax.
        output = (output1 + output2) / 2
        _, predicted = torch.max(output.data, 1)
        return 'real' if predicted == 1 else 'fake'
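
A minimal usage sketch, assuming the two default checkpoints exist under weights/ and that this file is importable as test_image_fusion; the --image flag and the __main__ wrapper below are illustrative, not part of the committed file:

    import argparse

    from test_image_fusion import Test

    if __name__ == "__main__":
        # Hypothetical driver: --image is an assumed flag, not from the commit.
        parser = argparse.ArgumentParser(description="Classify one image with the HH + FFT fusion models")
        parser.add_argument("--image", required=True, help="path to the image to classify")
        args = parser.parse_args()

        tester = Test()  # loads the two default checkpoints listed above
        print(tester.testimage(args.image))  # prints 'real' or 'fake'

Note that testimage averages the two models' raw outputs before the argmax; averaging softmax probabilities instead is a common alternative when the two streams are scaled differently.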