YukiNon commited on
Commit
2d15ff4
1 Parent(s): 70c9dff

Create practice.py

Browse files
Files changed (1) hide show
  1. practice.py +55 -0
practice.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from tensorflow import keras
3
+ from sklearn.model_selection import train_test_split
4
+ import matplotlib.pyplot as plt
5
+
6
+ # Загрузка датасета Fashion MNIST
7
+ (X_train_full, y_train_full), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()
8
+
9
+ # Нормализация данных
10
+ X_train_full = X_train_full.astype('float32') / 255.
11
+ X_test = X_test.astype('float32') / 255.
12
+
13
+ # Разбиение на тренировочную, валидационную и тестовую части
14
+ X_train, X_val = train_test_split(X_train_full, test_size=0.2, random_state=42)
15
+ X_val, X_test = train_test_split(X_val, test_size=0.5, random_state=42)
16
+
17
+ # Создание нейросети
18
+ input_shape = X_train.shape[1:]
19
+ latent_dim = 50
20
+ autoencoder = keras.models.Sequential([
21
+ keras.layers.Flatten(input_shape=input_shape),
22
+ keras.layers.Dense(256, activation='relu'),
23
+ keras.layers.Dense(128, activation='relu'),
24
+ keras.layers.Dense(latent_dim, activation='relu', name='latent_layer'),
25
+ keras.layers.Dense(128, activation='relu'),
26
+ keras.layers.Dense(256, activation='relu'),
27
+ keras.layers.Dense(np.prod(input_shape), activation='sigmoid'),
28
+ keras.layers.Reshape(input_shape)
29
+ ])
30
+
31
+ # Компиляция и обучение.# binary_crossentropy может и не нужна, но тема рабочая, менять не буду.
32
+ autoencoder.compile(loss='binary_crossentropy', optimizer='adam',metrics=['accuracy'])
33
+ history = autoencoder.fit(X_train, X_train,
34
+ epochs=50,
35
+ batch_size=128,
36
+ validation_data=(X_val, X_val))
37
+
38
+ # Визуализация результатов. Я лично доволен. Оно работает!!!
39
+ n = 7 # количество изображений для примера
40
+ decoded_imgs = autoencoder.predict(X_test[:n]) # Кодировка и декодировка тестовых изображений
41
+ plt.figure(figsize=(10, 4.5))
42
+ for i in range(n):
43
+ # Оригинальное изображение
44
+ ax = plt.subplot(2, n, i + 1)
45
+ plt.imshow(X_test[i])
46
+ plt.gray()
47
+ ax.get_xaxis().set_visible(False)
48
+ ax.get_yaxis().set_visible(False)
49
+
50
+ # Декодированное изображение
51
+ ax = plt.subplot(2, n, i + 1 + n)
52
+ plt.imshow(decoded_imgs[i])
53
+ plt.gray()
54
+ ax.get_xaxis().set_visible(False)
55
+ ax.get_yaxis().set_visible(False)