File size: 2,443 Bytes
2d15ff4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import numpy as np
from tensorflow import keras
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Загрузка датасета Fashion MNIST
(X_train_full, y_train_full), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()

# Нормализация данных
X_train_full = X_train_full.astype('float32') / 255.
X_test = X_test.astype('float32') / 255.

# Разбиение на тренировочную, валидационную и тестовую части
X_train, X_val = train_test_split(X_train_full, test_size=0.2, random_state=42)
X_val, X_test = train_test_split(X_val, test_size=0.5, random_state=42)

# Создание нейросети
input_shape = X_train.shape[1:]
latent_dim = 50
autoencoder = keras.models.Sequential([
    keras.layers.Flatten(input_shape=input_shape),
    keras.layers.Dense(256, activation='relu'),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(latent_dim, activation='relu', name='latent_layer'),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(256, activation='relu'),
    keras.layers.Dense(np.prod(input_shape), activation='sigmoid'),
    keras.layers.Reshape(input_shape)
])

# Компиляция и обучение.# binary_crossentropy может и не нужна, но тема рабочая, менять не буду.
autoencoder.compile(loss='binary_crossentropy', optimizer='adam',metrics=['accuracy'])
history = autoencoder.fit(X_train, X_train,
                          epochs=50,
                          batch_size=128,
                          validation_data=(X_val, X_val))

# Визуализация результатов.  Я лично доволен. Оно работает!!!
n = 7 # количество изображений для примера
decoded_imgs = autoencoder.predict(X_test[:n]) # Кодировка и декодировка тестовых изображений
plt.figure(figsize=(10, 4.5))
for i in range(n):
    # Оригинальное изображение
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(X_test[i])
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # Декодированное изображение
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i])
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)