# cat-vs-dog / src/models.py
# eddydecena's picture
# Add: Source
# e4497d1
from typing import Tuple
from typing import Optional
import tensorflow as tf
from tensorflow.keras import layers
from keras_tuner import HyperModel
class MakeHyperModel(HyperModel):
    """KerasTuner hypermodel that builds a convolutional image classifier.

    The network consists of two plain convolution stages followed by four
    blocks of separable convolutions with projected residual connections,
    a final separable convolution, global average pooling, dropout and a
    dense classification head. The only tuned hyperparameter is the Adam
    learning rate (``"learning_rate"``).

    Args:
        input_shape: Shape passed to ``tf.keras.Input`` (batch dimension
            excluded).
        num_classes: Number of target classes. Exactly ``2`` selects a
            single sigmoid unit with binary cross-entropy; any other value
            selects a softmax head with ``num_classes`` units and
            categorical cross-entropy.
        data_augmentation: Optional model applied to the inputs before
            rescaling. When ``None`` the inputs are used as-is.
    """

    def __init__(
        self,
        input_shape: Tuple[int, int, int],
        num_classes: int,
        data_augmentation: Optional[tf.keras.Sequential] = None,
    ) -> None:
        # Run the base-class initialiser so that whatever state KerasTuner's
        # HyperModel sets up is present on this instance.
        super().__init__()
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.data_augmentation = data_augmentation

    def build(self, hp) -> tf.keras.Model:
        """Build and compile the model for one hyperparameter trial.

        Args:
            hp: KerasTuner hyperparameters object; used to choose the
                learning rate from ``[1e-2, 1e-3, 1e-4]``.

        Returns:
            A compiled ``tf.keras.Model`` tracking the ``accuracy`` metric.
        """
        inputs = tf.keras.Input(shape=self.input_shape)

        # Identity check: avoids invoking ``__ne__`` on a Keras object.
        if self.data_augmentation is not None:
            x = self.data_augmentation(inputs)
        else:
            x = inputs

        # Presumably inputs are 0-255 pixel values, scaled here by 1/255 --
        # TODO confirm against the data pipeline.
        x = layers.Rescaling(1.0 / 255)(x)

        # Entry stage: two plain convolutions, the first one strided.
        x = layers.Conv2D(32, 3, strides=2, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)

        x = layers.Conv2D(64, 3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)

        # Kept aside so each block can add a residual from its own input.
        previous_block_activation = x

        for size in [128, 256, 512, 728]:
            x = layers.Activation('relu')(x)
            x = layers.SeparableConv2D(size, 3, padding='same')(x)
            x = layers.BatchNormalization()(x)

            x = layers.Activation("relu")(x)
            x = layers.SeparableConv2D(size, 3, padding='same')(x)
            x = layers.BatchNormalization()(x)

            x = layers.MaxPooling2D(3, strides=2, padding='same')(x)

            # 1x1 strided convolution projects the block input to the same
            # channel count and stride as the pooled branch before adding.
            residual = layers.Conv2D(size, 1, strides=2, padding='same')(
                previous_block_activation
            )
            x = layers.add([x, residual])
            previous_block_activation = x

        x = layers.SeparableConv2D(1024, 3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)

        x = layers.GlobalAveragePooling2D()(x)

        if self.num_classes == 2:
            # Two classes: one sigmoid output unit.
            activation = 'sigmoid'
            loss_fn = 'binary_crossentropy'
            units = 1
        else:
            # NOTE(review): 'categorical_crossentropy' is normally paired
            # with one-hot encoded labels -- confirm the dataset provides
            # them.
            activation = 'softmax'
            loss_fn = 'categorical_crossentropy'
            units = self.num_classes

        x = layers.Dropout(0.5)(x)
        outputs = layers.Dense(units, activation=activation)(x)

        model = tf.keras.Model(inputs, outputs)
        model.compile(
            optimizer=tf.keras.optimizers.Adam(
                hp.Choice("learning_rate", values=[1e-2, 1e-3, 1e-4])
            ),
            loss=loss_fn,
            metrics=['accuracy'],
        )
        return model