# mnist-classify / app.py
# Repository page header captured with the file (not part of the program):
# Zai | deploy | e34d96b | raw | history blame contribute delete | 691 Bytes
import gradio as gr
from model import Net, predict
import torch
import torchvision.transforms as transforms
from PIL import Image
# Instantiate the network, restore its saved parameters with every tensor
# mapped onto the CPU, then put the module in evaluation mode.
# NOTE(review): torch.load deserializes the checkpoint via pickle -- only ship
# a checkpoint file you trust.
model = Net()
model.load_state_dict(
    torch.load("mnist_model.pth", map_location="cpu"),
)
model.eval()
# Preprocessing pipeline applied to each submitted image before inference.
transform = transforms.Compose(
    [
        # Collapse the image to a single grayscale channel.
        transforms.Grayscale(num_output_channels=1),
        # Scale to 28x28 pixels; the size is passed as a (height, width) tuple.
        transforms.Resize(size=(28, 28)),
        # Convert the PIL image into a PyTorch tensor.
        transforms.ToTensor(),
    ]
)
def predict_image(image):
    """Run the model on an image submitted through the Gradio interface.

    Args:
        image: Image data as delivered by the ``gr.Image`` input component,
            in a form ``PIL.Image.fromarray`` accepts, or ``None`` when the
            user submits without providing an image.

    Returns:
        Whatever ``predict`` returns for the preprocessed image; Gradio
        renders it in the text output.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # An empty image input arrives as None; report that clearly instead
        # of letting Image.fromarray(None) fail with an opaque error.
        raise gr.Error("Please provide an image before submitting.")

    pil_image = Image.fromarray(image)
    # unsqueeze(0) adds a leading dimension of size one, i.e. a batch
    # containing just this image.
    input_tensors = transform(pil_image).unsqueeze(0)
    return predict(model, input_tensors)
# Wire the prediction function into a simple web UI: one image input and one
# text output, then start serving it.
app = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(),
    outputs="text",
)
app.launch()