# autoencoders/autoencoders_MNIST_denoiser.py
# Uploaded by makiisthebes ("Upload 6 files", commit 0743e7c).
# Autoencoder usages for denoising MNIST dataset.
# Learning how to use autoencoder for a denoising task.
# This will later be used in my project for denoising MOS gas-sensor data.
# YouTube Resource Credits: https://www.youtube.com/watch?v=Sm54KXD-L1k
# We encode an image and then decode it to check whether the reconstruction is denoised. The bottleneck layer
# forces the network to lose some information, so it must learn the essential features of the image.
# Noise reduction is a common problem; autoencoders address it by learning to reconstruct a clean image from a noisy one.
# We will be using Keras and TensorFlow for this task.
from keras.datasets import mnist
from tensorflow.keras.layers import Conv2D, MaxPooling2D, UpSampling2D
from tensorflow.keras.models import Sequential
import numpy as np
import random, cv2
from matplotlib import pyplot as plt
def sp_noise(image, prob, salt_value=None):
    '''
    Add salt-and-pepper noise to an image.

    Each pixel (indexed by the first two axes of ``image``) is independently
    set to pepper (0) with probability ``prob`` or to salt (``salt_value``)
    with probability ``prob``. All other pixels are copied unchanged. For a
    multi-channel image every channel of a corrupted pixel is overwritten.

    image: numpy array, e.g. (H, W) or (H, W, C). It is not modified.
    prob: Probability of each noise type. For prob <= 0.5 roughly 2*prob of
        the pixels are corrupted. Where the two ranges overlap (prob > 0.5),
        pepper takes precedence.
    salt_value: Value written for salt pixels. Defaults to 255 for integer
        images and 1.0 for floating-point images (e.g. data scaled to 0-1).

    Returns a new array with the same shape and dtype as ``image``.
    '''
    image = np.asarray(image)
    if salt_value is None:
        # Integer images keep the historical 0-255 salt value; float images are taken to be in 0-1.
        salt_value = 1.0 if np.issubdtype(image.dtype, np.floating) else 255
    # Keep the input dtype: a forced uint8 output truncates 0-1 float pixels to 0.
    output = image.copy()
    thres = 1 - prob
    # One uniform draw per pixel, vectorised instead of a Python double loop.
    rdn = np.random.random(image.shape[:2])
    output[rdn > thres] = salt_value
    # Pepper is applied last so it wins where both masks overlap, matching the original if/elif order.
    output[rdn < prob] = 0
    return output
# Load MNIST. The labels are not needed: the clean images themselves are the targets.
(x_train, _), (x_test, _) = mnist.load_data()
# The noisy images will be the network inputs and the clean images the targets,
# so the autoencoder learns to map a corrupted digit back to the original.
# Scale pixel values to the 0-1 range.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
# Add an explicit single channel axis: (N, 28, 28) -> (N, 28, 28, 1), as Conv2D expects.
x_train = x_train.reshape((len(x_train), 28, 28, 1))
x_test = x_test.reshape((len(x_test), 28, 28, 1))
# Corrupt the images with additive Gaussian noise (mean 0, std 1) scaled by noise_factor.
noise_factor = 0.5
x_train_noisy = x_train + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_train.shape)
x_test_noisy = x_test + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=x_test.shape)
# Clip back to the valid 0-1 range and cast to float32. np.random.normal returns
# float64, which would double the memory of the noisy copies and mismatch the
# float32 clean targets.
x_train_noisy = np.clip(x_train_noisy, 0., 1.).astype('float32')
x_test_noisy = np.clip(x_test_noisy, 0., 1.).astype('float32')
# Uncomment to preview some noisy test images:
# plt.figure(figsize=(20, 2))
# for i in range(1, 10):
#     ax = plt.subplot(1, 10, i)
#     plt.imshow(x_test_noisy[i].reshape(28, 28), cmap="binary")
# plt.show()
SIZE = 28  # Height and width of an input image.
DIMENSIONS = 1  # Number of input channels (grayscale).
model = Sequential()
# Encoder: each Conv2D + MaxPooling2D stage halves the spatial size
# ('same' pooling rounds up): 28 -> 14 -> 7 -> 4 -> 2 -> 1.
model.add(Conv2D(32, (3, 3), activation='relu', padding='same', input_shape=(SIZE, SIZE, DIMENSIONS)))
model.add(MaxPooling2D((2, 2), padding='same'))  # 28x28 -> 14x14
model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(MaxPooling2D((2, 2), padding='same'))  # 14x14 -> 7x7
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(MaxPooling2D((2, 2), padding='same'))  # 7x7 -> 4x4
model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(MaxPooling2D((2, 2), padding='same'))  # 4x4 -> 2x2
model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
# Bottleneck: 2x2 -> 1x1, i.e. a 1x1x512 code.
model.add(MaxPooling2D((2, 2), padding='same'))
# Decoder: each Conv2D + UpSampling2D stage doubles the spatial size: 1 -> 2 -> 4 -> 8 -> 16 -> 32.
model.add(Conv2D(512, (3, 3), activation='relu', padding='same'))
model.add(UpSampling2D((2, 2)))  # 1x1 -> 2x2
model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
model.add(UpSampling2D((2, 2)))  # 2x2 -> 4x4
model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
model.add(UpSampling2D((2, 2)))  # 4x4 -> 8x8
model.add(Conv2D(64, (3, 3), activation='relu', padding='same'))
model.add(UpSampling2D((2, 2)))  # 8x8 -> 16x16
model.add(Conv2D(32, (3, 3), activation='relu', padding='same'))
model.add(UpSampling2D((2, 2)))  # 16x16 -> 32x32
# Output layer: a 5x5 'valid' (unpadded) convolution trims 32x32 back to 28x28.
# Sigmoid keeps predictions in the same 0-1 range as the clean targets; ReLU was
# unbounded above and a single dead output unit would stall training.
model.add(Conv2D(DIMENSIONS, (5, 5), activation='sigmoid'))
# 'accuracy' is meaningless for a pixel-wise regression target; track mean absolute error instead.
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae'])
model.summary()
# Noisy images are the inputs, clean images the targets.
model.fit(x=x_train_noisy, y=x_train, epochs=3, batch_size=256, shuffle=True, validation_data=(x_test_noisy, x_test))
test_loss, test_mae = model.evaluate(x_test_noisy, x_test)
print(f"Test MSE: {test_loss:.4f}  Test MAE: {test_mae:.4f}")
# Save the trained model in the native Keras format. Keras 3's model.save()
# requires a '.keras' (or '.h5') extension and rejects the previous '.model'.
model.save('denoising_autoencoder.keras')
# Denoise the test set and show the first n_examples images in a 3-row grid:
# row 1 = noisy input, row 2 = denoised output, row 3 = clean original.
no_noise_img = model.predict(x_test_noisy)
n_examples = 10
plt.figure(figsize=(20, 6))
for i in range(n_examples):
    plt.subplot(3, n_examples, i + 1)
    plt.imshow(x_test_noisy[i].reshape(28, 28), cmap="binary")
    plt.axis('off')
    plt.subplot(3, n_examples, n_examples + i + 1)
    plt.imshow(no_noise_img[i].reshape(28, 28), cmap="binary")
    plt.axis('off')
    plt.subplot(3, n_examples, 2 * n_examples + i + 1)
    plt.imshow(x_test[i].reshape(28, 28), cmap="binary")
    plt.axis('off')
plt.tight_layout()
# Keep the figure open until a key or mouse button is pressed.
plt.waitforbuttonpress()