# fashion-attribute-detection / image-attribute.py
# Sanjeev Malla
# First commit
# 4bca390
import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Locations of the DeepFashion training and validation image folders.
train_data_dir = 'data/train'
validation_data_dir = 'data/validation'

# Width of the classifier output layer and number of images per batch.
num_classes = 50
batch_size = 32

# Both generators scale pixel values by the same factor; only the training
# generator additionally applies random shear, zoom and horizontal flips.
# NOTE(review): the ImageNet ResNet50 weights loaded later in this file are
# documented to expect `resnet50.preprocess_input` rather than 1/255 scaling;
# confirm the rescale is intentional before changing it, since any inference
# code must apply the same preprocessing.
pixel_rescale = 1.0 / 255
validation_datagen = ImageDataGenerator(rescale=pixel_rescale)
train_datagen = ImageDataGenerator(
    rescale=pixel_rescale,
    horizontal_flip=True,
    zoom_range=0.2,
    shear_range=0.2,
)
#
# Build the classifier: an ImageNet-pretrained ResNet50 backbone (without its
# original classification head) followed by a new head with num_classes outputs.
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Collapse the backbone's spatial feature map into one feature vector per image.
x = base_model.output
x = GlobalAveragePooling2D()(x)
# Fully connected layer with 1024 units.
x = Dense(1024, activation='relu')(x)
# The data generators in this file use class_mode='categorical', i.e. one-hot
# labels with exactly one class per image, so the output layer is a softmax
# over the classes (paired with categorical cross-entropy below).
predictions = Dense(num_classes, activation='softmax')(x)

# Create the model.
model = Model(inputs=base_model.input, outputs=predictions)

# Freeze the backbone so that only the newly added head is trained.
for layer in base_model.layers:
    layer.trainable = False

# `learning_rate` replaces the deprecated `lr` argument, which newer Keras
# versions reject. With sigmoid + binary cross-entropy on one-hot targets the
# 'accuracy' metric resolves to per-output binary accuracy, which is trivially
# high; categorical cross-entropy makes it per-image categorical accuracy.
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
# Build the directory-backed batch iterators for training and validation.
# Both share the same target image size, batch size and one-hot label encoding.
flow_options = dict(
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
)
train_generator = train_datagen.flow_from_directory(train_data_dir, **flow_options)
validation_generator = validation_datagen.flow_from_directory(
    validation_data_dir, **flow_options
)
# Train the model.
# Steps are counted in batches, not images: len(generator) is the number of
# batches needed to cover the dataset once (ceil(samples / batch_size)).
# Passing the raw sample count would make every epoch (and every validation
# pass) iterate over the data roughly batch_size times.
model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=10,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
)

# Save the trained model.
model.save('deepfashion_attribute_model.h5')