imdaisylee
commited on
Commit
•
e5e0406
1
Parent(s):
e68f790
Upload deepfake_model.py
Browse files- deepfake_model.py +100 -0
deepfake_model.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Imports
|
2 |
+
import sys
|
3 |
+
import tensorflow as tf
|
4 |
+
import cv2
|
5 |
+
import pandas as pd
|
6 |
+
import numpy as np
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
|
9 |
+
|
10 |
+
batch_size = 32
|
11 |
+
num_epochs = 4
|
12 |
+
|
13 |
+
# Load the dataset
|
14 |
+
data_dir = '/Users/daisy./Documents/data'
|
15 |
+
dataset = tf.keras.preprocessing.image_dataset_from_directory(data_dir)
|
16 |
+
|
17 |
+
# Split the dataset into train, validation, and test sets
|
18 |
+
train_size = int(len(dataset) * 0.7)
|
19 |
+
val_size = int(len(dataset) * 0.2)
|
20 |
+
test_size = int(len(dataset) * 0.1)
|
21 |
+
|
22 |
+
train = dataset.take(train_size)
|
23 |
+
val = dataset.skip(train_size).take(val_size)
|
24 |
+
test = dataset.skip(train_size + val_size).take(test_size)
|
25 |
+
|
26 |
+
# Preprocess the data for the Xception model
|
27 |
+
tf.keras.backend.clear_session()
|
28 |
+
|
29 |
+
preprocess = tf.keras.applications.xception.preprocess_input
|
30 |
+
|
31 |
+
# Apply preprocessing before batching
|
32 |
+
train = train.map(lambda x, y: (preprocess(x), y))
|
33 |
+
val = val.map(lambda x, y: (preprocess(x), y))
|
34 |
+
test = test.map(lambda x, y: (preprocess(x), y))
|
35 |
+
|
36 |
+
# Shuffle and batch the datasets, leaving the last batch incomplete
|
37 |
+
train_set = train.shuffle(buffer_size=len(train)).unbatch().batch(batch_size).repeat(num_epochs).prefetch(1)
|
38 |
+
valid_set = val.unbatch().batch(batch_size).repeat(num_epochs)
|
39 |
+
test_set = test.unbatch().batch(batch_size).repeat(num_epochs)
|
40 |
+
|
41 |
+
# Check the element specifications
|
42 |
+
# print("Train set element spec:", train_set.element_spec)
|
43 |
+
# print("Validation set element spec:", valid_set.element_spec)
|
44 |
+
# print("Test set element spec:", test_set.element_spec)
|
45 |
+
|
46 |
+
tf.random.set_seed(42) # extra code – ensures reproducibility
|
47 |
+
|
48 |
+
input = tf.keras.Input(shape=(256, 256, 3))
|
49 |
+
inp_resized = tf.keras.layers.Lambda(lambda X: tf.image.resize(X, (299, 299)))(input)
|
50 |
+
|
51 |
+
base_model = tf.keras.applications.xception.Xception(weights="imagenet", include_top=False, input_tensor=inp_resized)
|
52 |
+
|
53 |
+
avg = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
|
54 |
+
output = tf.keras.layers.Dense(1, activation="sigmoid")(avg)
|
55 |
+
model = tf.keras.Model(inputs=base_model.input, outputs=output)
|
56 |
+
|
57 |
+
# freezing layers
|
58 |
+
for layer in base_model.layers:
|
59 |
+
layer.trainable = False
|
60 |
+
|
61 |
+
class PrintShapeCallback(tf.keras.callbacks.Callback):
|
62 |
+
def on_epoch_begin(self, epoch, logs=None):
|
63 |
+
print(f"Epoch {epoch+1}: Input shape:", self.model.input_shape)
|
64 |
+
|
65 |
+
print_shape_callback = PrintShapeCallback()
|
66 |
+
|
67 |
+
# Set the number of steps per epoch
|
68 |
+
steps_per_epoch = 5
|
69 |
+
|
70 |
+
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
|
71 |
+
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])
|
72 |
+
history = model.fit(train_set, steps_per_epoch=steps_per_epoch, validation_data=valid_set, epochs=10, callbacks=[print_shape_callback])
|
73 |
+
|
74 |
+
# plot model performance
|
75 |
+
acc = history.history['accuracy']
|
76 |
+
val_acc = history.history['val_accuracy']
|
77 |
+
loss = history.history['loss']
|
78 |
+
val_loss = history.history['val_loss']
|
79 |
+
epochs_range = range(1, len(history.epoch) + 1)
|
80 |
+
|
81 |
+
plt.figure(figsize=(15,5))
|
82 |
+
|
83 |
+
plt.subplot(1, 2, 1)
|
84 |
+
plt.plot(epochs_range, acc, label='Train Set')
|
85 |
+
plt.plot(epochs_range, val_acc, label='Val Set')
|
86 |
+
plt.legend(loc="best")
|
87 |
+
plt.xlabel('Epochs')
|
88 |
+
plt.ylabel('Accuracy')
|
89 |
+
plt.title('Model Accuracy')
|
90 |
+
|
91 |
+
plt.subplot(1, 2, 2)
|
92 |
+
plt.plot(epochs_range, loss, label='Train Set')
|
93 |
+
plt.plot(epochs_range, val_loss, label='Val Set')
|
94 |
+
plt.legend(loc="best")
|
95 |
+
plt.xlabel('Epochs')
|
96 |
+
plt.ylabel('Loss')
|
97 |
+
plt.title('Model Loss')
|
98 |
+
|
99 |
+
plt.tight_layout()
|
100 |
+
plt.show()
|