# I2VGen-XL / test.py
# Author: kevinwang676
# Commit 34a163e (verified): "Update test.py"
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, Activation
from tensorflow.keras import regularizers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
def _dense_bn_relu(x, units, name, activity_l1=None, kernel_l2=None, dropout_rate=0.0):
    """Apply Dense -> BatchNormalization -> ReLU (-> Dropout) to Keras tensor `x`.

    Args:
        x: Keras tensor to transform.
        units: Width of the Dense layer.
        name: Name given to the Dense layer (BN/Activation/Dropout are auto-named).
        activity_l1: If not None, L1 activity-regularization factor applied to
            the Dense output.
        kernel_l2: If not None, L2 regularization factor applied to the Dense kernel.
        dropout_rate: If > 0, a Dropout layer with this rate is appended.

    Returns:
        The resulting Keras tensor.
    """
    x = Dense(
        units,
        activity_regularizer=None if activity_l1 is None else regularizers.l1(activity_l1),
        kernel_regularizer=None if kernel_l2 is None else regularizers.l2(kernel_l2),
        name=name,
    )(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    if dropout_rate > 0:
        x = Dropout(dropout_rate)(x)
    return x


# Run configuration (previously inlined literals).
ae_device = '/gpu:0'
ae_epochs = 50
ae_batch_size = 16
ae_checkpoint_path = 'best_model.h5'

with tf.device(ae_device):
    encoding_dim1 = 500
    encoding_dim2 = 200
    lambda_act = 0.0001    # L1 activity-regularization factor for the encoder Dense outputs
    lambda_weight = 0.001  # L2 kernel-regularization factor for the encoder Dense layers

    # NOTE(review): num_in_neurons, x_train, x_train_noisy, x_test and x_test_noisy
    # are not defined in this section — assumed to be set earlier in the file.
    input_data = Input(shape=(num_in_neurons,))

    # Encoder: two regularized Dense -> BN -> ReLU -> Dropout(0.5) stages.
    # NOTE(review): the L1 activity penalty acts on the Dense output *before*
    # BatchNormalization rescales it; confirm this placement gives the intended sparsity.
    encoded = _dense_bn_relu(input_data, encoding_dim1, 'encoder1',
                             activity_l1=lambda_act, kernel_l2=lambda_weight,
                             dropout_rate=0.5)
    encoded = _dense_bn_relu(encoded, encoding_dim2, 'encoder2',
                             activity_l1=lambda_act, kernel_l2=lambda_weight,
                             dropout_rate=0.5)

    # Decoder: unregularized Dense -> BN -> ReLU, then a projection back to the
    # input width with a sigmoid output.
    decoded = _dense_bn_relu(encoded, encoding_dim1, 'decoder1')
    decoded = Dense(num_in_neurons, name='decoder2')(decoded)
    # NOTE(review): sigmoid output with MSE loss assumes the clean targets are
    # scaled to [0, 1] — confirm against the preprocessing of x_train / x_test.
    decoded = Activation('sigmoid')(decoded)

    autoencoder = Model(inputs=input_data, outputs=decoded)
    autoencoder.compile(optimizer=Adam(learning_rate=0.001), loss='mse')

    # Callbacks: all three monitor validation loss.
    callbacks = [
        EarlyStopping(monitor='val_loss', patience=10, verbose=1),
        ModelCheckpoint(ae_checkpoint_path, monitor='val_loss', save_best_only=True, verbose=1),
        ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, verbose=1),
    ]

    # Training: the model learns to reconstruct x_train from x_train_noisy,
    # validated on the (x_test_noisy -> x_test) pair.
    print('Training the autoencoder')
    autoencoder.fit(x_train_noisy, x_train,
                    epochs=ae_epochs,
                    batch_size=ae_batch_size,
                    shuffle=True,
                    validation_data=(x_test_noisy, x_test),
                    callbacks=callbacks)

    # EarlyStopping is not given restore_best_weights, so reload the checkpoint
    # that ModelCheckpoint wrote for the lowest val_loss seen.
    autoencoder.load_weights(ae_checkpoint_path)

    # Mark every layer of the reloaded model non-trainable.
    # NOTE(review): `trainable` only matters for a later compile()/fit() (e.g. reusing
    # this model as a frozen sub-model); predict() does not update weights either
    # way — confirm whether this line is actually needed for inference.
    autoencoder.trainable = False