# Captured from a Hugging Face Spaces page (Space status at capture time: "Runtime error").
import tensorflow as tf | |
from tensorflow.keras.models import Model | |
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, Activation | |
from tensorflow.keras import regularizers | |
from tensorflow.keras.optimizers import Adam | |
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau | |
# Train a denoising autoencoder (500-200-500 bottleneck) on noisy -> clean
# pairs, keep the best checkpoint by validation loss, and reload it.
#
# Use the first GPU when TensorFlow can see one; otherwise fall back to the
# CPU explicitly instead of relying on soft device placement to silently
# redirect a hard-coded '/gpu:0' on CPU-only hosts.
_device = '/gpu:0' if tf.config.list_physical_devices('GPU') else '/cpu:0'

with tf.device(_device):
    # Widths of the two encoder stages; the decoder mirrors the first one.
    encoding_dim1 = 500
    encoding_dim2 = 200
    # L1 strength on encoder layer outputs and L2 strength on encoder kernels.
    lambda_act = 0.0001
    lambda_weight = 0.001

    # NOTE(review): num_in_neurons, x_train, x_train_noisy, x_test and
    # x_test_noisy are not defined in this section; they must be created
    # earlier in the file/notebook before this code runs.
    input_data = Input(shape=(num_in_neurons,))

    # Encoder: two Dense -> BatchNormalization -> ReLU -> Dropout stages.
    # NOTE(review): the L1 activity regularizer sits on the Dense output,
    # i.e. before BatchNormalization/ReLU, so it penalizes pre-activations
    # rather than the final codes -- confirm this is intended.
    encoded = Dense(encoding_dim1,
                    activity_regularizer=regularizers.l1(lambda_act),
                    kernel_regularizer=regularizers.l2(lambda_weight),
                    name='encoder1')(input_data)
    encoded = BatchNormalization()(encoded)
    encoded = Activation('relu')(encoded)
    encoded = Dropout(0.5)(encoded)

    encoded = Dense(encoding_dim2,
                    activity_regularizer=regularizers.l1(lambda_act),
                    kernel_regularizer=regularizers.l2(lambda_weight),
                    name='encoder2')(encoded)
    encoded = BatchNormalization()(encoded)
    encoded = Activation('relu')(encoded)
    encoded = Dropout(0.5)(encoded)

    # Decoder: expand back to the input width, without regularizers or dropout.
    decoded = Dense(encoding_dim1, name='decoder1')(encoded)
    decoded = BatchNormalization()(decoded)
    decoded = Activation('relu')(decoded)
    decoded = Dense(num_in_neurons, name='decoder2')(decoded)
    # Sigmoid bounds reconstructions to (0, 1); combined with the MSE loss
    # this presumably expects clean targets scaled to [0, 1] -- TODO confirm.
    decoded = Activation('sigmoid')(decoded)

    autoencoder = Model(inputs=input_data, outputs=decoded)
    autoencoder.compile(optimizer=Adam(learning_rate=0.001), loss='mse')

    callbacks = [
        # Stop once val_loss has not improved for 10 epochs.
        EarlyStopping(monitor='val_loss', patience=10, verbose=1),
        # Write the model to disk whenever val_loss reaches a new minimum.
        ModelCheckpoint('best_model.h5', monitor='val_loss',
                        save_best_only=True, verbose=1),
        # Cut the learning rate 10x after 5 stagnant epochs, so it gets a
        # chance to help before early stopping's 10-epoch patience runs out.
        ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5,
                          verbose=1),
    ]

    print('Training the autoencoder')
    # Denoising setup: noisy arrays are the inputs, clean arrays the targets.
    history = autoencoder.fit(
        x_train_noisy, x_train,
        epochs=50,
        batch_size=16,
        shuffle=True,
        validation_data=(x_test_noisy, x_test),
        callbacks=callbacks,
    )

    # EarlyStopping is not configured to restore the best weights, so the
    # in-memory model holds the last epoch's weights; reload the best
    # checkpoint written by ModelCheckpoint.
    autoencoder.load_weights('best_model.h5')

    # Mark every layer non-trainable. This is not required for predict()
    # (Dropout is inactive and BatchNormalization uses its moving statistics
    # at inference). It only matters if the model is trained again, and in
    # Keras a trainable change made after compile() takes effect only once
    # compile() is called again.
    autoencoder.trainable = False