# cat-vs-dog / train.py
# Author: eddydecena
# Commit dbeb7c3: "First model version"
from tensorflow.keras.utils import get_file
import tensorflow as tf
from keras_tuner import RandomSearch
from keras_tuner import Objective
from src.preprocessing import delete_corrupted_image
from src.draw import visualize_data
from src.preprocessing import get_data_augmentation
from src.models import MakeHyperModel
from src.config import DATASET_URL
from src.config import CACHE_DIR
from src.config import CACHE_SUBDIR
from src.config import DATASET_PATH
from src.config import IMAGE_SIZE
from src.config import BATCH_SIZE
from src.config import EPOCHS
# Both subsets must be built with the same validation_split and seed;
# otherwise the training and validation splits are not guaranteed to be
# disjoint.
VALIDATION_SPLIT = 0.2
SPLIT_SEED = 1337


def _load_subset(subset):
    """Load one side of the train/validation split from DATASET_PATH.

    Args:
        subset: Either 'training' or 'validation'.

    Returns:
        A batched tf.data.Dataset of (images, labels) built by
        image_dataset_from_directory with the shared split settings.
    """
    return tf.keras.preprocessing.image_dataset_from_directory(
        DATASET_PATH,
        validation_split=VALIDATION_SPLIT,
        subset=subset,
        seed=SPLIT_SEED,
        image_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
    )


# Download and extract the dataset archive into CACHE_DIR/CACHE_SUBDIR
# (get_file reuses an already-downloaded file).
get_file(origin=DATASET_URL, extract=True, cache_dir=CACHE_DIR, cache_subdir=CACHE_SUBDIR)
# NOTE(review): presumably removes undecodable images from the 'Cat' and
# 'Dog' class folders; its return value is printed as-is — confirm in
# src.preprocessing.
print(delete_corrupted_image(DATASET_PATH, ('Cat', 'Dog')))

train_ds = _load_subset('training')
val_ds = _load_subset('validation')

# The datasets already yield whole batches, so prefetch's buffer_size is a
# count of batches, not images. Reusing BATCH_SIZE here would buffer
# BATCH_SIZE batches; let tf.data pick the buffer size instead.
train_ds = train_ds.prefetch(buffer_size=tf.data.AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

data_augmentation = get_data_augmentation()
# Preview training samples with the augmentation pipeline (see src.draw).
visualize_data(train_ds, data_augmentation=data_augmentation)

# image_dataset_from_directory defaults to RGB, hence 3 channels.
hypermodel = MakeHyperModel(input_shape=IMAGE_SIZE + (3,), num_classes=2, data_augmentation=data_augmentation)
tuner = RandomSearch(
    hypermodel,
    objective=Objective("val_accuracy", direction="max"),
    max_trials=3,
    # Each sampled configuration is trained twice to reduce noise in the
    # objective.
    executions_per_trial=2,
    overwrite=True,
    directory='tuner_model',
    project_name='cat-vs-dog'
)
tuner.search_space_summary()
tuner.search(train_ds, epochs=EPOCHS, validation_data=val_ds)

# Previously the result of get_best_hyperparameters() was discarded;
# keep it and report it so the search outcome is actually visible.
best_hps = tuner.get_best_hyperparameters(num_trials=1)
if best_hps:
    print('Best hyperparameters:', best_hps[0].values)