"""Streamlit app that predicts which digit an uploaded image shows.

The uploaded image is converted to grayscale, resized to 28x28, normalized,
and classified by a small CNN whose weights are read from MODEL_PATH.
"""

# import libraries
from PIL import Image
from torchvision import models, transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import streamlit as st

# Trained weights for ComplexModel(1, 10); read lazily on the first prediction.
MODEL_PATH = 'model-all-digit.pth'

# Preprocessing applied to every input image, built once rather than on each
# predict() call.
# NOTE(review): the Normalize constants are the commonly used MNIST mean/std;
# presumably they match the preprocessing used when the checkpoint was
# trained. Confirm against the training code.
TRANSFORM = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# set title of app
st.title("What number is it?")
st.write("")

# enable users to upload images for the model to make predictions
file_up = st.file_uploader("Upload an image")


class ComplexModel(nn.Module):
    """Small CNN classifier: two conv + 2x2 max-pool stages, then two linear layers.

    Args:
        input_channels: Number of channels in the input images.
        output_size: Number of output classes.

    ``forward`` takes a batch of shape ``(N, input_channels, H, W)``. ``fc1``
    is sized for ``64 * 7 * 7`` features, so after the two 2x2 poolings the
    spatial map must be 7x7 (e.g. H = W = 28). It returns log-probabilities
    (``log_softmax`` over dim 1) of shape ``(N, output_size)``.
    """

    def __init__(self, input_channels, output_size):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, output_size)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = x.view(x.size(0), -1)  # flatten to (N, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = ComplexModel(1, 10)

# True once MODEL_PATH has been loaded into `model`.
# NOTE(review): Streamlit re-executes this script on every interaction, so this
# caches per script run only; st.cache_resource would cache across reruns if
# the installed Streamlit version provides it.
_weights_loaded = False


def _ensure_weights_loaded():
    """Load MODEL_PATH into the module-level ``model`` the first time it is needed."""
    global _weights_loaded
    if not _weights_loaded:
        # map_location='cpu' lets a checkpoint saved from a GPU load on a
        # CPU-only host; inference in predict() runs on CPU tensors.
        model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
        _weights_loaded = True


def predict(image):
    """Return the class the model predicts for *image*.

    Args:
        image: An already-open ``PIL.Image.Image``, or anything
            ``PIL.Image.open`` accepts (a path or a binary file-like object).

    Returns:
        int: Index (0-9) of the highest-scoring class.

    Raises:
        OSError: If *image* is not an opened image and PIL cannot read it.
    """
    _ensure_weights_loaded()
    model.eval()

    # load the image, pre-process it
    img = image if isinstance(image, Image.Image) else Image.open(image)
    img = img.convert('L')  # the model takes a single (grayscale) channel
    batch_t = torch.unsqueeze(TRANSFORM(img), 0)  # add batch dim -> (1, 1, 28, 28)

    # make the prediction; no_grad because no gradients are needed here
    with torch.no_grad():
        output = model(batch_t)
    return torch.argmax(output, dim=1).item()


if file_up is not None:
    try:
        image = Image.open(file_up)
        image.load()  # force a full decode so corrupt files fail here
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for
        # non-image files and OSError for truncated/corrupt ones.
        st.error("Could not read the uploaded file as an image.")
    else:
        # display image that user uploaded
        st.image(image, caption='Uploaded Image.', use_column_width=True)
        st.write("")
        # Reuse the decoded image instead of re-opening the upload.
        label = predict(image)
        # print out the predicted digit
        st.header(f"Prediction: {label}")