imdaisylee commited on
Commit
e5e0406
1 Parent(s): e68f790

Upload deepfake_model.py

Browse files
Files changed (1) hide show
  1. deepfake_model.py +100 -0
deepfake_model.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Imports
2
+ import sys
3
+ import tensorflow as tf
4
+ import cv2
5
+ import pandas as pd
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+
9
+
10
+ batch_size = 32
11
+ num_epochs = 4
12
+
13
+ # Load the dataset
14
+ data_dir = '/Users/daisy./Documents/data'
15
+ dataset = tf.keras.preprocessing.image_dataset_from_directory(data_dir)
16
+
17
+ # Split the dataset into train, validation, and test sets
18
+ train_size = int(len(dataset) * 0.7)
19
+ val_size = int(len(dataset) * 0.2)
20
+ test_size = int(len(dataset) * 0.1)
21
+
22
+ train = dataset.take(train_size)
23
+ val = dataset.skip(train_size).take(val_size)
24
+ test = dataset.skip(train_size + val_size).take(test_size)
25
+
26
+ # Preprocess the data for the Xception model
27
+ tf.keras.backend.clear_session()
28
+
29
+ preprocess = tf.keras.applications.xception.preprocess_input
30
+
31
+ # Apply preprocessing before batching
32
+ train = train.map(lambda x, y: (preprocess(x), y))
33
+ val = val.map(lambda x, y: (preprocess(x), y))
34
+ test = test.map(lambda x, y: (preprocess(x), y))
35
+
36
+ # Shuffle and batch the datasets, leaving the last batch incomplete
37
+ train_set = train.shuffle(buffer_size=len(train)).unbatch().batch(batch_size).repeat(num_epochs).prefetch(1)
38
+ valid_set = val.unbatch().batch(batch_size).repeat(num_epochs)
39
+ test_set = test.unbatch().batch(batch_size).repeat(num_epochs)
40
+
41
+ # Check the element specifications
42
+ # print("Train set element spec:", train_set.element_spec)
43
+ # print("Validation set element spec:", valid_set.element_spec)
44
+ # print("Test set element spec:", test_set.element_spec)
45
+
46
+ tf.random.set_seed(42) # extra code – ensures reproducibility
47
+
48
+ input = tf.keras.Input(shape=(256, 256, 3))
49
+ inp_resized = tf.keras.layers.Lambda(lambda X: tf.image.resize(X, (299, 299)))(input)
50
+
51
+ base_model = tf.keras.applications.xception.Xception(weights="imagenet", include_top=False, input_tensor=inp_resized)
52
+
53
+ avg = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
54
+ output = tf.keras.layers.Dense(1, activation="sigmoid")(avg)
55
+ model = tf.keras.Model(inputs=base_model.input, outputs=output)
56
+
57
+ # freezing layers
58
+ for layer in base_model.layers:
59
+ layer.trainable = False
60
+
61
+ class PrintShapeCallback(tf.keras.callbacks.Callback):
62
+ def on_epoch_begin(self, epoch, logs=None):
63
+ print(f"Epoch {epoch+1}: Input shape:", self.model.input_shape)
64
+
65
+ print_shape_callback = PrintShapeCallback()
66
+
67
+ # Set the number of steps per epoch
68
+ steps_per_epoch = 5
69
+
70
+ optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
71
+ model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])
72
+ history = model.fit(train_set, steps_per_epoch=steps_per_epoch, validation_data=valid_set, epochs=10, callbacks=[print_shape_callback])
73
+
74
+ # plot model performance
75
+ acc = history.history['accuracy']
76
+ val_acc = history.history['val_accuracy']
77
+ loss = history.history['loss']
78
+ val_loss = history.history['val_loss']
79
+ epochs_range = range(1, len(history.epoch) + 1)
80
+
81
+ plt.figure(figsize=(15,5))
82
+
83
+ plt.subplot(1, 2, 1)
84
+ plt.plot(epochs_range, acc, label='Train Set')
85
+ plt.plot(epochs_range, val_acc, label='Val Set')
86
+ plt.legend(loc="best")
87
+ plt.xlabel('Epochs')
88
+ plt.ylabel('Accuracy')
89
+ plt.title('Model Accuracy')
90
+
91
+ plt.subplot(1, 2, 2)
92
+ plt.plot(epochs_range, loss, label='Train Set')
93
+ plt.plot(epochs_range, val_loss, label='Val Set')
94
+ plt.legend(loc="best")
95
+ plt.xlabel('Epochs')
96
+ plt.ylabel('Loss')
97
+ plt.title('Model Loss')
98
+
99
+ plt.tight_layout()
100
+ plt.show()