# digit-recognition/model.py
# Script for digit recognition model.
# Last commit: 25a44a2 (danilommarano)
from pathlib import Path
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
# Load the MNIST handwritten-digit images (pixel values 0-255) together with
# their integer class labels, already split into training and test sets.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Normalize the pixel values to range [0, 1].
# Cast to float32 *before* dividing: dividing the raw uint8 arrays by 255.0
# promotes them to float64. That doubles their memory footprint, yet Keras
# layers compute in float32 by default, so the extra precision is discarded.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
# Reshape the data to 4D (number of samples, height, width, channels).
# Conv2D expects an explicit trailing channel axis, which is 1 for these
# single-channel (grayscale) 28x28 images.
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
# Create the model: two convolution + max-pooling stages extract spatial
# features, then a dense head maps them to 10 softmax outputs, one
# probability per digit class.
# The input shape is declared with an explicit Input as the first entry
# rather than via Conv2D's ``input_shape=`` keyword. Keras 3 emits a
# UserWarning for ``input_shape`` on a layer and recommends this form.
# ``tf.keras.Input`` is accepted by both tf.keras 2.x and Keras 3, and the
# resulting layer stack is the same as before.
model = Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    Conv2D(32, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D((2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax'),
])
# Compile the model for training. Sparse categorical cross-entropy is used
# because the labels are passed as plain integer class indices (they are
# never one-hot encoded). Adam is the optimizer, and accuracy is tracked
# alongside the loss.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)
# Train the model for 5 epochs with mini-batches of 32.
# validation_split=0.1 reserves the last 10% of the training arrays (Keras
# selects them before shuffling) and reports validation metrics each epoch.
model.fit(x_train, y_train, epochs=5, batch_size=32, validation_split=0.1)
# Evaluate on the held-out test set. x_test / y_test are loaded and
# preprocessed for exactly this purpose but were otherwise never used.
# return_dict=True keys the results by metric name, so the report does not
# depend on the order in which Keras lists its metrics.
test_metrics = model.evaluate(x_test, y_test, verbose=0, return_dict=True)
print(
    "Test set: "
    + ", ".join(f"{name}={value:.4f}" for name, value in test_metrics.items())
)
# Persist the trained network as an HDF5 file placed in the same directory
# as this script (independent of which file names exist elsewhere).
path = Path(__file__).parent / "model.h5"
model.save(path)