eaglelandsonce commited on
Commit
f2036f2
·
verified ·
1 Parent(s): a982b06

Create 13_TransferLearning.py

Browse files
Files changed (1) hide show
  1. pages/13_TransferLearning.py +104 -0
pages/13_TransferLearning.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ from tensorflow.keras import layers, models, applications
4
+ from tensorflow.keras.preprocessing.image import ImageDataGenerator
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+
8
+ # Set dataset paths
9
+ train_dir = 'data/train'
10
+ validation_dir = 'data/validation'
11
+
12
+ # Streamlit app
13
+ st.title("Transfer Learning with VGG16 for Image Classification")
14
+
15
+ # Input parameters
16
+ batch_size = st.slider("Batch Size", 16, 128, 32, 16)
17
+ epochs = st.slider("Epochs", 5, 50, 10, 5)
18
+
19
+ # Data augmentation and preprocessing
20
+ train_datagen = ImageDataGenerator(
21
+ rescale=1./255,
22
+ rotation_range=40,
23
+ width_shift_range=0.2,
24
+ height_shift_range=0.2,
25
+ shear_range=0.2,
26
+ zoom_range=0.2,
27
+ horizontal_flip=True,
28
+ fill_mode='nearest'
29
+ )
30
+
31
+ validation_datagen = ImageDataGenerator(rescale=1./255)
32
+
33
+ train_generator = train_datagen.flow_from_directory(
34
+ train_dir,
35
+ target_size=(150, 150),
36
+ batch_size=batch_size,
37
+ class_mode='binary'
38
+ )
39
+
40
+ validation_generator = validation_datagen.flow_from_directory(
41
+ validation_dir,
42
+ target_size=(150, 150),
43
+ batch_size=batch_size,
44
+ class_mode='binary'
45
+ )
46
+
47
+ # Load the pre-trained VGG16 model
48
+ base_model = applications.VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))
49
+
50
+ # Freeze the convolutional base
51
+ base_model.trainable = False
52
+
53
+ # Add custom layers on top
54
+ model = models.Sequential([
55
+ base_model,
56
+ layers.Flatten(),
57
+ layers.Dense(256, activation='relu'),
58
+ layers.Dropout(0.5),
59
+ layers.Dense(1, activation='sigmoid') # Change the output layer based on the number of classes
60
+ ])
61
+
62
+ model.summary()
63
+
64
+ # Compile the model
65
+ model.compile(optimizer='adam',
66
+ loss='binary_crossentropy', # Change loss function based on the number of classes
67
+ metrics=['accuracy'])
68
+
69
+ # Train the model
70
+ if st.button("Train Model"):
71
+ with st.spinner("Training the model..."):
72
+ history = model.fit(
73
+ train_generator,
74
+ steps_per_epoch=train_generator.samples // train_generator.batch_size,
75
+ epochs=epochs,
76
+ validation_data=validation_generator,
77
+ validation_steps=validation_generator.samples // validation_generator.batch_size
78
+ )
79
+
80
+ st.success("Model training completed!")
81
+
82
+ # Display training curves
83
+ st.subheader("Training and Validation Accuracy")
84
+ fig, ax = plt.subplots()
85
+ ax.plot(history.history['accuracy'], label='Training Accuracy')
86
+ ax.plot(history.history['val_accuracy'], label='Validation Accuracy')
87
+ ax.set_xlabel('Epoch')
88
+ ax.set_ylabel('Accuracy')
89
+ ax.legend()
90
+ st.pyplot(fig)
91
+
92
+ st.subheader("Training and Validation Loss")
93
+ fig, ax = plt.subplots()
94
+ ax.plot(history.history['loss'], label='Training Loss')
95
+ ax.plot(history.history['val_loss'], label='Validation Loss')
96
+ ax.set_xlabel('Epoch')
97
+ ax.set_ylabel('Loss')
98
+ ax.legend()
99
+ st.pyplot(fig)
100
+
101
+ # Evaluate the model
102
+ if st.button("Evaluate Model"):
103
+ test_loss, test_acc = model.evaluate(validation_generator, verbose=2)
104
+ st.write(f"Validation accuracy: {test_acc}")