| | import tensorflow as tf |
| | import os |
| | import matplotlib.pyplot as plt |
| | from dotenv import load_dotenv |
| | from tensorflow.keras.preprocessing.image import ImageDataGenerator |
| |
|
# Populate os.environ from a .env file via python-dotenv, so that the
# os.getenv lookup below can pick up TRAIN_DATASET from it.
load_dotenv()

# --- Data settings ---
# Root directory handed to flow_from_directory; None if the variable is unset.
TRAIN_DATASET = os.getenv("TRAIN_DATASET")
IMG_SIZE = (224, 224)  # target_size every image is resized to
BATCH_SIZE: int = 32

# --- Training settings ---
EPOCHS: int = 8
OPTIMIZER: str = 'adam'
LOSS_FUNC: str = 'binary_crossentropy'
| |
|
| | |
def load_data():
    """Build the training and validation iterators over ``TRAIN_DATASET``.

    Both iterators read the same directory tree. With ``validation_split=0.2``,
    ``flow_from_directory`` assigns each class's files to either the
    "training" or the "validation" subset. Only the training subset is
    augmented; validation images are only rescaled, so validation metrics
    reflect unmodified images.

    Returns:
        tuple: ``(train_data, val_data)`` directory iterators yielding batches
        of ``BATCH_SIZE`` images resized to ``IMG_SIZE``, pixel values scaled
        by 1/255, with binary labels (``class_mode="binary"``).

    Raises:
        ValueError: if the ``TRAIN_DATASET`` environment variable is unset or
            empty (otherwise Keras fails later with an obscure error on a
            ``None`` directory).
    """
    if not TRAIN_DATASET:
        raise ValueError(
            "TRAIN_DATASET environment variable is not set; point it at the "
            "image directory (one sub-directory per class)."
        )

    validation_split = 0.2

    # Random flips/zooms are for training only.
    train_datagen = ImageDataGenerator(
        validation_split=validation_split,
        rescale=1.0 / 255,
        horizontal_flip=True,
        zoom_range=0.2,
    )
    # Rescale only. The identical validation_split keeps the "validation"
    # subset disjoint from the "training" subset selected above.
    val_datagen = ImageDataGenerator(
        validation_split=validation_split,
        rescale=1.0 / 255,
    )

    flow_kwargs = dict(
        directory=TRAIN_DATASET,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode="binary",
    )

    train_data = train_datagen.flow_from_directory(
        subset="training",
        shuffle=True,
        **flow_kwargs,
    )
    # Evaluation order does not affect epoch-level metrics; a fixed order
    # keeps validation runs reproducible.
    val_data = val_datagen.flow_from_directory(
        subset="validation",
        shuffle=False,
        **flow_kwargs,
    )

    return train_data, val_data
| |
|
| | |
def build_model():
    """Create and compile the binary image classifier.

    Architecture:
        - Three stages of Conv2D(3x3, ReLU) followed by MaxPooling2D(2x2,
          stride 2), with 32, 64 and 128 filters.
        - Then Flatten -> Dense(512, ReLU) -> Dense(1, sigmoid).

    The input shape is ``(*IMG_SIZE, 3)``.

    Returns:
        tf.keras.Sequential: the model compiled with ``OPTIMIZER``,
        ``LOSS_FUNC`` and an accuracy metric.
    """
    layers = tf.keras.layers

    feature_extractor = []
    for stage, filters in enumerate((32, 64, 128)):
        # Only the very first convolution declares the input shape.
        conv_kwargs = {"input_shape": (*IMG_SIZE, 3)} if stage == 0 else {}
        feature_extractor.append(
            layers.Conv2D(filters, (3, 3), activation="relu", **conv_kwargs)
        )
        feature_extractor.append(layers.MaxPooling2D(pool_size=2, strides=2))

    classifier_head = [
        layers.Flatten(),
        layers.Dense(512, activation="relu"),
        layers.Dense(1, activation="sigmoid"),
    ]

    model = tf.keras.Sequential(feature_extractor + classifier_head)
    model.compile(
        optimizer=OPTIMIZER,
        loss=LOSS_FUNC,
        metrics=["accuracy"],
    )
    return model
| |
|
| | |
def _plot_history(history):
    """Plot per-epoch accuracy and loss curves from a Keras ``History``.

    Accuracy and loss use separate axes because their scales differ. Each
    axis shows both the training and the validation curve.
    """
    metrics = history.history
    epochs = range(1, len(metrics["accuracy"]) + 1)

    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))

    acc_ax.plot(epochs, metrics["accuracy"], label="Train Accuracy")
    acc_ax.plot(epochs, metrics["val_accuracy"], label="Validation Accuracy")
    acc_ax.set_title("Training and Validation Accuracy")
    acc_ax.set_xlabel("Epoch")
    acc_ax.legend()

    loss_ax.plot(epochs, metrics["loss"], label="Train Loss")
    loss_ax.plot(epochs, metrics["val_loss"], label="Validation Loss")
    loss_ax.set_title("Training and Validation Loss")
    loss_ax.set_xlabel("Epoch")
    loss_ax.legend()

    fig.tight_layout()
    plt.show()


def main():
    """Train the classifier, save it, then plot its training curves.

    Steps:
        1. Load data with ``load_data`` and build the model with
           ``build_model``.
        2. Fit for ``EPOCHS`` epochs, validating on ``val_data`` each epoch.
        3. Write the model to ``cat_dog_model.h5``.
        4. Show accuracy and loss plots.
    """
    train_data, val_data = load_data()
    model = build_model()

    history = model.fit(
        train_data,
        epochs=EPOCHS,
        validation_data=val_data,
    )

    # Persist the model before plotting, so it is saved even if plotting
    # fails or plt.show() blocks on an interactive backend.
    model.save("cat_dog_model.h5")

    _plot_history(history)
| |
|
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
| |
|