|
import datetime |
|
import os |
|
from typing import List |
|
import absl |
|
import keras_tuner |
|
import tensorflow as tf |
|
from tensorflow.keras.optimizers import Adam |
|
import tensorflow_transform as tft |
|
|
|
from tensorflow_cloud import CloudTuner |
|
from tfx.v1.components import TunerFnResult |
|
from tfx.components.trainer.fn_args_utils import DataAccessor |
|
from tfx.components.trainer.fn_args_utils import FnArgs |
|
from tfx.dsl.io import fileio |
|
from tfx_bsl.tfxio import dataset_options |
|
import tfx.extensions.google_cloud_ai_platform.constants as vertex_const |
|
import tfx.extensions.google_cloud_ai_platform.trainer.executor as vertex_training_const |
|
import tfx.extensions.google_cloud_ai_platform.tuner.executor as vertex_tuner_const |
|
|
|
# Nominal split sizes used to derive steps-per-epoch in tuner_fn/run_fn.
# NOTE(review): assumed to match the example counts materialized upstream — confirm.
_TRAIN_DATA_SIZE = 128

_EVAL_DATA_SIZE = 128

# Batch sizes for the training and evaluation splits.
_TRAIN_BATCH_SIZE = 32

_EVAL_BATCH_SIZE = 32

# Learning rates for the classifier head and for fine-tuning.
# NOTE(review): neither constant is referenced in this chunk; the tuned
# search space in _get_hyperparameters() hard-codes its own values.
_CLASSIFIER_LEARNING_RATE = 1e-3

_FINETUNE_LEARNING_RATE = 7e-6

# Epoch budget for training the classifier head; run_fn validates it
# against the total epochs implied by fn_args.train_steps.
_CLASSIFIER_EPOCHS = 30


# Raw feature keys; _transformed_name() maps them to post-Transform names.
_IMAGE_KEY = "image"

_LABEL_KEY = "label"
|
|
|
|
|
def INFO(text: str):
    """Emit *text* at INFO severity through absl logging."""
    absl.logging.info(text)
|
|
|
|
|
def _transformed_name(key: str) -> str: |
|
return key + "_xf" |
|
|
|
|
|
def _get_signature(model):
    """Build the serving-signature dict for *model*.

    The single "serving_default" entry accepts a float32 batch of
    224x224x3 images named after the transformed image key.
    """
    image_spec = tf.TensorSpec(
        shape=[None, 224, 224, 3],
        dtype=tf.float32,
        name=_transformed_name(_IMAGE_KEY),
    )
    serve_fn = _get_serve_image_fn(model).get_concrete_function(image_spec)
    return {"serving_default": serve_fn}
|
|
|
|
|
def _get_serve_image_fn(model):
    """Wrap *model* in a tf.function usable as a serving entry point."""

    @tf.function
    def serve_image_fn(image_tensor):
        # Forward the batch straight through the Keras model.
        return model(image_tensor)

    return serve_image_fn
|
|
|
|
|
def _image_augmentation(image_features):
    """Randomly flip, pad to 250x250, then random-crop back to 224x224.

    Operates on a batch; the crop keeps the batch dimension intact.
    """
    num_images = tf.shape(image_features)[0]
    flipped = tf.image.random_flip_left_right(image_features)
    padded = tf.image.resize_with_crop_or_pad(flipped, 250, 250)
    return tf.image.random_crop(padded, (num_images, 224, 224, 3))
|
|
|
|
|
def _data_augmentation(feature_dict):
    """Replace the transformed image feature with its augmented version.

    Mutates *feature_dict* in place and returns it for use in Dataset.map.
    """
    image_key = _transformed_name(_IMAGE_KEY)
    feature_dict[image_key] = _image_augmentation(feature_dict[image_key])
    return feature_dict
|
|
|
|
|
def _input_fn(
    file_pattern: List[str],
    data_accessor: DataAccessor,
    tf_transform_output: tft.TFTransformOutput,
    is_train: bool = False,
    batch_size: int = 200,
) -> tf.data.Dataset:
    """Create a batched tf.data.Dataset of (features, label) tuples.

    Args:
      file_pattern: File paths or patterns of the materialized examples.
      data_accessor: TFX DataAccessor used to build the dataset.
      tf_transform_output: Output of the upstream Transform component.
      is_train: When True, random image augmentation is applied per batch.
      batch_size: Number of examples per batch.

    Returns:
      The (optionally augmented) dataset ready for Keras fit/tune.
    """
    options = dataset_options.TensorFlowDatasetOptions(
        batch_size=batch_size, label_key=_transformed_name(_LABEL_KEY)
    )
    dataset = data_accessor.tf_dataset_factory(
        file_pattern,
        options,
        tf_transform_output.transformed_metadata.schema,
    )

    if is_train:
        # Augmentation is applied only to the training split.
        dataset = dataset.map(
            lambda features, label: (_data_augmentation(features), label)
        )

    return dataset
|
|
|
|
|
def _get_hyperparameters() -> keras_tuner.HyperParameters:
    """Return the hyperparameter search space used by the tuners."""
    search_space = keras_tuner.HyperParameters()
    # Only the learning rate is tuned; default matches the lower choice.
    search_space.Choice("learning_rate", [1e-3, 1e-2], default=1e-3)
    return search_space
|
|
|
|
|
def _build_keras_model(hparams: keras_tuner.HyperParameters) -> tf.keras.Model:
    """Assemble and compile a ResNet50-backed 10-class image classifier.

    Args:
      hparams: HyperParameters providing "learning_rate".

    Returns:
      A compiled tf.keras.Model with a frozen ImageNet backbone.
    """
    backbone = tf.keras.applications.ResNet50(
        input_shape=(224, 224, 3), include_top=False, weights="imagenet", pooling="max"
    )
    # NOTE(review): clearing input_spec presumably lets the backbone accept
    # the renamed input layer below — confirm against the Keras version used.
    backbone.input_spec = None
    backbone.trainable = False  # only the head below is trained

    model = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(
                input_shape=(224, 224, 3), name=_transformed_name(_IMAGE_KEY)
            ),
            backbone,
            tf.keras.layers.Dropout(0.1),
            tf.keras.layers.Dense(10, activation="softmax"),
        ]
    )

    model.compile(
        loss="sparse_categorical_crossentropy",
        optimizer=Adam(learning_rate=hparams.get("learning_rate")),
        metrics=["sparse_categorical_accuracy"],
    )
    model.summary(print_fn=INFO)

    return model
|
|
|
|
|
def cloud_tuner_fn(fn_args: FnArgs) -> TunerFnResult:
    """Build a Vertex AI CloudTuner plus the fit kwargs that drive it.

    Args:
      fn_args: Holds args used to tune the model, including custom_config
        entries for the Vertex project and region.

    Returns:
      A TunerFnResult holding the tuner and the train/eval fit kwargs.
    """
    TUNING_ARGS_KEY = vertex_tuner_const.TUNING_ARGS_KEY
    VERTEX_PROJECT_KEY = "project"
    VERTEX_REGION_KEY = "region"

    # BUG FIX: steps_per_epoch was referenced in fit_kwargs below but never
    # defined in this function, raising NameError at runtime. Compute it the
    # same way tuner_fn does. (Unused TRAINING_ARGS_KEY local removed.)
    steps_per_epoch = int(_TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE)

    tuner = CloudTuner(
        _build_keras_model,
        max_trials=6,
        hyperparameters=_get_hyperparameters(),
        project_id=fn_args.custom_config[TUNING_ARGS_KEY][VERTEX_PROJECT_KEY],
        region=fn_args.custom_config[TUNING_ARGS_KEY][VERTEX_REGION_KEY],
        objective="val_sparse_categorical_accuracy",
        directory=fn_args.working_dir,
    )

    tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path)

    train_dataset = _input_fn(
        fn_args.train_files,
        fn_args.data_accessor,
        tf_transform_output,
        is_train=True,
        batch_size=_TRAIN_BATCH_SIZE,
    )

    eval_dataset = _input_fn(
        fn_args.eval_files,
        fn_args.data_accessor,
        tf_transform_output,
        is_train=False,
        batch_size=_EVAL_BATCH_SIZE,
    )

    return TunerFnResult(
        tuner=tuner,
        fit_kwargs={
            "x": train_dataset,
            "validation_data": eval_dataset,
            "steps_per_epoch": steps_per_epoch,
            "validation_steps": fn_args.eval_steps,
        },
    )
|
|
|
|
|
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
    """Build a local RandomSearch tuner plus the datasets that drive it.

    Args:
      fn_args: Holds args used to tune the model as name/value pairs.

    Returns:
      A TunerFnResult holding the tuner and the train/eval fit kwargs.
    """
    train_steps_per_epoch = int(_TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE)

    search = keras_tuner.RandomSearch(
        _build_keras_model,
        max_trials=6,
        hyperparameters=_get_hyperparameters(),
        allow_new_entries=False,
        objective=keras_tuner.Objective("val_sparse_categorical_accuracy", "max"),
        directory=fn_args.working_dir,
        project_name="img_classification_tuning",
    )

    transform_output = tft.TFTransformOutput(fn_args.transform_graph_path)

    training_data = _input_fn(
        fn_args.train_files,
        fn_args.data_accessor,
        transform_output,
        is_train=True,
        batch_size=_TRAIN_BATCH_SIZE,
    )
    validation_data = _input_fn(
        fn_args.eval_files,
        fn_args.data_accessor,
        transform_output,
        is_train=False,
        batch_size=_EVAL_BATCH_SIZE,
    )

    fit_kwargs = {
        "x": training_data,
        "validation_data": validation_data,
        "steps_per_epoch": train_steps_per_epoch,
        "validation_steps": fn_args.eval_steps,
    }
    return TunerFnResult(tuner=search, fit_kwargs=fit_kwargs)
|
|
|
|
|
def run_fn(fn_args: FnArgs):
    """Train the classifier head and export the model with a serving signature.

    Args:
      fn_args: Holds args used to train the model as name/value pairs.

    Raises:
      ValueError: If _CLASSIFIER_EPOCHS exceeds the total epochs implied by
        fn_args.train_steps.
    """
    steps_per_epoch = int(_TRAIN_DATA_SIZE / _TRAIN_BATCH_SIZE)
    total_epochs = int(fn_args.train_steps / steps_per_epoch)
    if _CLASSIFIER_EPOCHS > total_epochs:
        raise ValueError("Classifier epochs is greater than the total epochs")

    # Consistency fix: the tuner functions read the transform graph via
    # fn_args.transform_graph_path; use the same (non-deprecated) attribute
    # here instead of the transform_output alias.
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_graph_path)

    train_dataset = _input_fn(
        fn_args.train_files,
        fn_args.data_accessor,
        tf_transform_output,
        is_train=True,
        batch_size=_TRAIN_BATCH_SIZE,
    )

    eval_dataset = _input_fn(
        fn_args.eval_files,
        fn_args.data_accessor,
        tf_transform_output,
        is_train=False,
        batch_size=_EVAL_BATCH_SIZE,
    )

    INFO("Tensorboard logging to {}".format(fn_args.model_run_dir))
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=fn_args.model_run_dir, update_freq="batch"
    )

    # Prefer hyperparameters produced by a Tuner component; otherwise fall
    # back to the static search-space defaults.
    if fn_args.hyperparameters:
        hparams = keras_tuner.HyperParameters.from_config(fn_args.hyperparameters)
    else:
        hparams = _get_hyperparameters()
    # BUG FIX: a stray "$" before the braces made the log message read
    # literally "...: ${...}" instead of interpolating the config.
    INFO(f"HyperParameters for training: {hparams.get_config()}")

    model = _build_keras_model(hparams)
    model.fit(
        train_dataset,
        epochs=_CLASSIFIER_EPOCHS,
        steps_per_epoch=steps_per_epoch,
        validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps,
        callbacks=[tensorboard_callback],
    )

    # Export as a TF SavedModel with an explicit serving signature so the
    # deployed model accepts raw image tensors under the transformed key.
    model.save(
        fn_args.serving_model_dir, save_format="tf", signatures=_get_signature(model)
    )
|
|