File size: 3,714 Bytes
4825f44 39cdc68 7a18648 4825f44 39cdc68 4825f44 39cdc68 d7be2f6 4825f44 39cdc68 d8baab8 4825f44 9f12b81 4825f44 9f12b81 39cdc68 85f061d 39cdc68 9f12b81 71076b0 39cdc68 85f061d 4825f44 e941ac3 4825f44 9ded5df 39cdc68 deeba82 39cdc68 560634b 39cdc68 9aac4d9 39cdc68 9aac4d9 39cdc68 9aac4d9 39cdc68 6b0f035 39cdc68 560634b 5c00344 560634b 5c00344 4825f44 5c00344 ea273f8 5e61f83 560634b 5e61f83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
# importing all the packages
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import numpy as np
import streamlit as st
import tensorflow as tf
import numpy as np
import keras
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
import streamlit as st
from streamlit_drawable_canvas import st_canvas
import cv2
from PIL import Image
import torchvision.transforms as transforms
import torch
from skorch import NeuralNetClassifier
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from model import Cnn
# Reading the data.
# cache=True lets scikit-learn keep the downloaded dataset on disk. Streamlit
# reruns this script top to bottom on user interaction (every canvas stroke
# when "Update in realtime" is on), and with cache=False each rerun fetched
# MNIST from OpenML again.
mnist = fetch_openml('mnist_784', as_frame=False, cache=True)
X = mnist.data.astype('float32')
y = mnist.target.astype('int64')
# Scale pixel intensities from [0, 255] to [0, 1]; any image fed to the model
# at inference time must be scaled the same way.
X /= 255.0
# Reshape each flat 784-pixel row to (channels=1, height=28, width=28) for the CNN.
XCnn = X.reshape(-1, 1, 28, 28)
# Fixed random_state keeps the split identical across reruns.
# NOTE(review): in the visible code only XCnn_train / y_train are consumed
# (by the skorch fit call, which runs with max_epochs=0); the test split is unused.
XCnn_train, XCnn_test, y_train, y_test = train_test_split(
    XCnn, y, test_size=0.25, random_state=42
)
torch.manual_seed(0)
from PIL import Image
import torchvision.transforms as transforms
torch.manual_seed(0)
# Build the network architecture (defined in the local `model` module) and
# load the trained parameters into it.
model = Cnn()
# Path to the saved model weights (a state_dict checkpoint).
model_weights_path = 'model_weights.pth'
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only host; it matches device='cpu' used by the wrapper below.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(torch.load(model_weights_path, map_location='cpu'))
# Set the model to evaluation mode for inference.
model.eval()
# Wrap the already-trained module in a skorch classifier so predictions can be
# made with cnn.predict(...).
# NOTE(review): assumes skorch reuses an already-instantiated module (instead of
# re-creating it) when fit() initializes the net, so the loaded weights are
# kept — confirm for the installed skorch version.
cnn = NeuralNetClassifier(
    module=model,
    max_epochs=0,  # zero epochs: no additional training is performed
    lr=0.002,  # has no effect while max_epochs=0; kept to mirror training setup
    optimizer=torch.optim.Adam,  # likewise never stepped while max_epochs=0
    device='cpu',  # run inference on the CPU
)
# With max_epochs=0 this only lets skorch initialize itself from the data;
# no gradient updates run.
cnn.fit(XCnn_train, y_train)
# Page heading.
st.title("Handwritten Text Digit Recognition")

# Sidebar controls, created top to bottom in the order they are displayed.
stroke_width = st.sidebar.slider("Stroke width: ", 1, 35, 32)
# Defaults give white ink on a black background.
stroke_color = st.sidebar.color_picker("Stroke color hex:", "#ffffff")
bg_color = st.sidebar.color_picker("Background color hex:", "#000000")
bg_image = st.sidebar.file_uploader("Background image:", type=["png", "jpg"])
drawing_tools = ("freedraw", "line", "rect", "circle", "transform", "polygon")
drawing_mode = st.sidebar.selectbox("Drawing tool:", drawing_tools)
realtime_update = st.sidebar.checkbox("Update in realtime", True)

# Optional uploaded picture shown underneath the strokes.
if bg_image:
    background = Image.open(bg_image)
else:
    background = None

show_toolbar = st.sidebar.checkbox("Display toolbar", True)

# 200x200 drawing surface configured from the sidebar settings; the shape
# fill colour is fixed (orange, semi-transparent).
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",
    stroke_width=stroke_width,
    stroke_color=stroke_color,
    background_color=bg_color,
    background_image=background,
    update_streamlit=realtime_update,
    height=200,
    width=200,
    drawing_mode=drawing_mode,
    display_toolbar=show_toolbar,
    key="full_app",
)
# Classify whatever is currently drawn on the canvas.
if canvas_result.image_data is not None:
    # streamlit-drawable-canvas documents image_data as an RGBA array;
    # work on a uint8 copy so the component's array is left untouched.
    image = canvas_result.image_data
    image1 = image.copy().astype('uint8')
    # The input is RGBA, so use the RGBA conversion: alpha is dropped and the
    # R/B luminance weights are applied to the correct channels.
    image1 = cv2.cvtColor(image1, cv2.COLOR_RGBA2GRAY)
    # INTER_AREA averages source pixels and is OpenCV's recommended method for
    # shrinking, so thin strokes survive the downscale to 28x28.
    image1 = cv2.resize(image1, (28, 28), interpolation=cv2.INTER_AREA)
    st.image(image1)
    if image1.min() == image1.max():
        # A uniform image (e.g. an empty canvas) has no digit to classify.
        st.info("Draw a digit on the canvas to get a prediction.")
    else:
        # Match the training preprocessing: shape (1, 1, 28, 28) and pixel
        # values scaled from [0, 255] to [0, 1], as done with X /= 255.0.
        model_input = image1.reshape(1, 1, 28, 28).astype('float32') / 255.0
        prediction = cnn.predict(model_input)
        # predict returns one label per input row; show the single digit.
        st.title(f"Handwritten Digit Prediction: {prediction[0]}")
|