# HandCode / train_model.py
# aikanava's picture
# Upload 3 files
# 6412482 verified
# raw
# history blame
# 2.86 kB
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import matplotlib.pyplot as plt
# === CONFIGURATION ===
DATA_DIR = 'asl_alphabet_train'  # Folder with A-Z subfolders containing images
MODEL_SAVE_PATH = 'trained_model/asl_model.h5'  # Target file of the ModelCheckpoint callback
IMG_SIZE = 64  # Images are resized to IMG_SIZE x IMG_SIZE before entering the network
BATCH_SIZE = 32
EPOCHS = 20  # Upper bound on epochs; the EarlyStopping callback may end training sooner
NUM_CLASSES = 26  # Number of units in the final softmax layer

# Create output directories if they don't exist.
# os.path.dirname() returns '' when MODEL_SAVE_PATH is a bare filename, and
# os.makedirs('') raises FileNotFoundError, so only create the directory when
# the path actually has a directory component.
_model_dir = os.path.dirname(MODEL_SAVE_PATH)
if _model_dir:
    os.makedirs(_model_dir, exist_ok=True)
os.makedirs('outputs', exist_ok=True)  # The training plot is saved here
# === DATA GENERATORS ===
# Fraction of the images in each class folder held out for validation. Both
# generators below must use the same value so that the 'training' and
# 'validation' subsets stay disjoint.
VALIDATION_SPLIT = 0.2

# Training pipeline: rescale pixel values to [0, 1] and apply random
# augmentation (rotation, zoom, shifts, horizontal flip).
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT,
    rotation_range=15,
    zoom_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
)

# Validation pipeline: rescale only. Drawing the validation subset from the
# augmenting generator would randomly distort the validation images, making
# val_loss / val_accuracy noisy and misleading the callbacks that monitor them.
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT,
)

# Augmented, shuffled batches of one-hot labelled images for training.
train_generator = train_datagen.flow_from_directory(
    DATA_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='training',
    shuffle=True,
    seed=42,
)

# Un-augmented batches in a fixed order for evaluation after each epoch.
validation_generator = val_datagen.flow_from_directory(
    DATA_DIR,
    target_size=(IMG_SIZE, IMG_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    subset='validation',
    shuffle=False,
    seed=42,
)
# === MODEL ARCHITECTURE ===
# Three convolution + max-pooling stages followed by a small dense classifier.
model = Sequential()

# Convolutional feature extractor (32 -> 64 -> 128 filters, 3x3 kernels).
model.add(Conv2D(filters=32, kernel_size=(3, 3), activation='relu',
                 input_shape=(IMG_SIZE, IMG_SIZE, 3)))
model.add(MaxPooling2D(pool_size=2, strides=2))
model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=2, strides=2))
model.add(Conv2D(filters=128, kernel_size=(3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=2, strides=2))

# Classifier head: flatten, one hidden layer with dropout, softmax over classes.
model.add(Flatten())
model.add(Dense(units=128, activation='relu'))
model.add(Dropout(rate=0.5))
model.add(Dense(units=NUM_CLASSES, activation='softmax'))

# Categorical cross-entropy matches the one-hot labels produced by
# class_mode='categorical' in the data generators.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Print the layer-by-layer overview.
model.summary()
# === CALLBACKS ===
# Write the model to MODEL_SAVE_PATH only when validation accuracy improves.
checkpoint = ModelCheckpoint(
    filepath=MODEL_SAVE_PATH,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
)
# Stop once validation loss has not improved for 5 epochs and restore the
# weights from the best epoch.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)
training_callbacks = [checkpoint, early_stop]

# === TRAINING ===
# The returned History object carries the per-epoch metrics plotted below.
history = model.fit(
    x=train_generator,
    epochs=EPOCHS,
    validation_data=validation_generator,
    callbacks=training_callbacks,
)
# === PLOT TRAINING HISTORY ===
# Two side-by-side panels: accuracy on the left, loss on the right, each
# showing the training curve and the validation curve per epoch.
fig, (accuracy_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))
epoch_metrics = history.history

accuracy_ax.plot(epoch_metrics['accuracy'], label='Train Accuracy')
accuracy_ax.plot(epoch_metrics['val_accuracy'], label='Val Accuracy')
accuracy_ax.legend()
accuracy_ax.set_title('Accuracy')

loss_ax.plot(epoch_metrics['loss'], label='Train Loss')
loss_ax.plot(epoch_metrics['val_loss'], label='Val Loss')
loss_ax.legend()
loss_ax.set_title('Loss')

# Save before showing so the file is written while the figure is still open.
fig.savefig('outputs/training_plot.png')
plt.show()