File size: 3,283 Bytes
d7ff06f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os
import glob
import shutil
import tensorflow as tf

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Shared settings for the input pipeline and the classifier head.
BATCH_SIZE = 32             # images per batch produced by the directory iterators
IMG_SIZE = (224, 224)       # spatial size every image is resized to (target_size)
IMG_SHAPE = (*IMG_SIZE, 3)  # IMG_SIZE with a trailing 3 appended, used as the model input_shape
num_classes = 5             # number of units in the final Dense layer

# 1. Download dataset
# get_file downloads the archive (or reuses the cached copy) and extracts it;
# the extracted 'flower_photos' folder is expected next to the archive.
_URL = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
path_to_zip = tf.keras.utils.get_file(origin=_URL, fname="flower_photos.tgz", extract=True)
base_dir = os.path.join(os.path.dirname(path_to_zip), 'flower_photos')
classes = ['roses', 'daisy', 'dandelion', 'sunflowers', 'tulips']

train_dir = os.path.join(base_dir, 'train')
val_dir = os.path.join(base_dir, 'val')

# create train and validation set
# Each class folder is split 80/20 and its files are moved into
# <base_dir>/train/<class> and <base_dir>/val/<class>.
for cl in classes:
    # glob returns files in arbitrary order; sorting makes the split
    # reproducible from one run / machine to the next.
    images = sorted(glob.glob(os.path.join(base_dir, cl, '*.jpg')))
    num_train = int(round(len(images) * 0.8))

    for split_dir, files in ((train_dir, images[:num_train]),
                             (val_dir, images[num_train:])):
        if not files:
            # Nothing to move (e.g. the images were already split by an
            # earlier run), so do not create an empty class folder.
            continue
        dest = os.path.join(split_dir, cl)
        # Create the destination once per class instead of checking per file.
        os.makedirs(dest, exist_ok=True)
        for image_path in files:
            shutil.move(image_path, dest)

# 2. Data augmentation
# preprocess_input has to be passed by keyword: the first positional parameter
# of ImageDataGenerator is `featurewise_center`, so a function passed
# positionally only switches that flag on and is never applied to the images.
# MobileNetV2's preprocess_input already scales pixels to [-1, 1], so no extra
# `rescale` is combined with it.
_preprocess = tf.keras.applications.mobilenet_v2.preprocess_input

# Training generator: pixel preprocessing plus random geometric augmentation.
image_gen = ImageDataGenerator(preprocessing_function=_preprocess,
                               rotation_range=45,
                               width_shift_range=0.2, height_shift_range=0.2,
                               shear_range=0.3, zoom_range=0.5, horizontal_flip=True)

# Validation generator: same pixel preprocessing, but no random augmentation,
# so validation metrics are computed on unmodified images.
val_image_gen = ImageDataGenerator(preprocessing_function=_preprocess)

# class_mode='sparse' yields integer labels, one per image.
train_data_gen = image_gen.flow_from_directory(batch_size=BATCH_SIZE, directory=train_dir,
                                               shuffle=True, target_size=IMG_SIZE, class_mode='sparse')

val_data_gen = val_image_gen.flow_from_directory(batch_size=BATCH_SIZE, directory=val_dir,
                                                 shuffle=True, target_size=IMG_SIZE, class_mode='sparse')

# 3. Build the classifier: ImageNet-pretrained MobileNet V2 without its top
#    layers, followed by a new classification head.
base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(
    input_shape=IMG_SHAPE,
    weights='imagenet',
    include_top=False,
)
# Keep the pre-trained weights fixed; only the new head is trainable.
base_model.trainable = False

# New head, applied in order on top of the base model's output.
head_layers = [
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1024, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(512, activation='relu'),
]
x = base_model.output
for head_layer in head_layers:
    x = head_layer(x)

# The last layer has no activation: it emits raw logits, one per class.
outputs = tf.keras.layers.Dense(num_classes)(x)
model = tf.keras.Model(base_model.inputs, outputs)

# Configure training. from_logits=True matches the activation-free output
# layer, and the sparse loss takes integer class labels.
sparse_ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=sparse_ce, metrics=['accuracy'])

# 4. Train
epochs = 10

# save best model
# The checkpoint is overwritten only when val_loss improves.
checkpoint_path = 'flowers/best.h5'
# Make sure the target folder exists before the first save; not every Keras
# version creates it on its own.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
checkpoint = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor='val_loss',
                                                save_best_only=True, mode='auto')
callback_list = [checkpoint]

# Use the `epochs` variable (previously a duplicated literal) so that changing
# it above actually changes the length of training.
history = model.fit(train_data_gen, epochs=epochs, validation_data=val_data_gen, callbacks=callback_list)