Gosula commited on
Commit
d8baab8
1 Parent(s): 156d0d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -19
app.py CHANGED
@@ -58,25 +58,25 @@ torch.manual_seed(0)
58
  # lr=0.002,
59
  # optimizer=torch.optim.Adam,
60
  # device=device,
61
- # )
62
- model=Cnn()
 
 
 
 
 
 
 
 
63
 
64
  # Specify the path to the saved model weights
65
  model_weights_path = 'model_weights.pth'
66
 
67
- # Load the model weights
68
- model.load_state_dict(torch.load(model_weights_path,map_location=torch.device('cpu')))
69
 
70
  # Set the model to evaluation mode for inference
71
  model.eval()
72
- # Create a NeuralNetClassifier using the loaded model
73
- cnn = NeuralNetClassifier(
74
- module=model,
75
- max_epochs=0, # Set max_epochs to 0 to avoid additional training
76
- lr=0.002, # You can set this to the learning rate used during training
77
- optimizer=torch.optim.Adam, # You can set the optimizer used during training
78
- device='cpu' # You can specify the device ('cpu' for CPU, 'cuda' for GPU, etc.)
79
- )
80
 
81
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 35, 32)
82
  stroke_color = st.sidebar.color_picker("Stroke color hex: ")
@@ -103,16 +103,23 @@ canvas_result = st_canvas(
103
 
104
  # Do something interesting with the image data and paths
105
  if canvas_result.image_data is not None:
106
- #st.image(canvas_result.image_data)
107
  image = canvas_result.image_data
108
  image1 = image.copy()
109
  image1 = image1.astype('uint8')
110
- image1 = cv2.cvtColor(image1,cv2.COLOR_BGR2GRAY)
111
- image1 = cv2.resize(image1,(28,28))
112
  st.image(image1)
113
 
114
- image1.resize(1,1,28,28)
115
- st.title(np.argmax(cnn.predict(image1)))
116
- if canvas_result.json_data is not None:
117
- st.dataframe(pd.json_normalize(canvas_result.json_data["objects"]))
 
 
 
 
 
 
 
 
118
 
 
58
  # lr=0.002,
59
  # optimizer=torch.optim.Adam,
60
  # device=device,
61
+ import streamlit as st
62
+ from st_canvas import st_canvas
63
+ import torch
64
+ from PIL import Image
65
+ import cv2
66
+ import numpy as np
67
+ from your_model_module import Cnn # Import your model architecture
68
+
69
+ # Create an instance of your model (Cnn model)
70
+ model = Cnn()
71
 
72
  # Specify the path to the saved model weights
73
  model_weights_path = 'model_weights.pth'
74
 
75
+ # Load the model weights onto a CPU device (if you want to use the CPU)
76
+ model.load_state_dict(torch.load(model_weights_path, map_location=torch.device('cpu')))
77
 
78
  # Set the model to evaluation mode for inference
79
  model.eval()
 
 
 
 
 
 
 
 
80
 
81
  stroke_width = st.sidebar.slider("Stroke width: ", 1, 35, 32)
82
  stroke_color = st.sidebar.color_picker("Stroke color hex: ")
 
103
 
104
  # Do something interesting with the image data and paths
105
  if canvas_result.image_data is not None:
 
106
  image = canvas_result.image_data
107
  image1 = image.copy()
108
  image1 = image1.astype('uint8')
109
+ image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2GRAY)
110
+ image1 = cv2.resize(image1, (28, 28))
111
  st.image(image1)
112
 
113
+ # Convert the image for prediction (assuming image1 is in the right format)
114
+ image1 = image1[np.newaxis, np.newaxis, ...] # Add batch and channel dimensions
115
+
116
+ # Perform prediction using the pre-trained model
117
+ with torch.no_grad():
118
+ tensor_image = torch.tensor(image1, dtype=torch.float32)
119
+ prediction = model(tensor_image)
120
+
121
+ # Display the predicted class
122
+ predicted_class = prediction.argmax().item()
123
+ st.title(f"Predicted Class: {predicted_class}")
124
+
125