# XXXX / app.py
# Last commit 64d20d7 by Nvd: "Update app.py"
import torch
import torchvision.transforms as transforms
import gradio as gr
from PIL import Image
from model import SimpleCNN
# Preprocessing pipeline shared by every call; built once at import time
# rather than on each request. NOTE(review): the Normalize statistics are the
# ImageNet mean/std — confirm they match what was used to train
# cifar10_model.pth.
_TRANSFORM = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Convert an input image into a normalized single-image batch tensor.

    Args:
        image: Either an array accepted by ``PIL.Image.fromarray`` (such as
            the numpy array Gradio's ``gr.Image`` delivers by default) or an
            already-open ``PIL.Image.Image``.

    Returns:
        A float tensor of shape ``(1, 3, 32, 32)``: the image converted to
        RGB, resized to 32x32, scaled to [0, 1] by ``ToTensor`` and then
        normalized per channel.
    """
    # Accept PIL images directly; only raw arrays need wrapping.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # Grayscale, palette or RGBA inputs would not have the three channels the
    # 3-element Normalize statistics expect, so force RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Prepend a batch dimension of size 1.
    return _TRANSFORM(image).unsqueeze(0)
def predict_image(model, image):
    """Return the index of the highest-scoring class for a single image.

    Args:
        model: Callable (e.g. the ``torch.nn.Module`` built in ``main``) that
            maps a batch tensor to a score tensor whose dimension 1 indexes
            classes.
        image: A batch tensor holding exactly one image, as produced by
            ``preprocess_image``.

    Returns:
        The predicted class index as a Python ``int``. When several classes
        tie for the maximum score, the first one is returned.
    """
    # Inference only: disable autograd tracking, so no `.data` detour is needed.
    with torch.no_grad():
        scores = model(image)
    # `.item()` requires the reduced result to hold a single element (batch of one).
    return scores.argmax(dim=1).item()
def main():
    """Load the trained classifier and serve it through a Gradio web UI.

    Builds a ``SimpleCNN``, loads its weights from ``cifar10_model.pth``
    (resolved relative to the current working directory), and launches a
    live ``gr.Interface`` that shows the predicted class index for each
    uploaded image.
    """
    model = SimpleCNN()
    # map_location="cpu" lets a checkpoint saved on a GPU machine load on a
    # CPU-only host instead of raising at startup.
    state_dict = torch.load("cifar10_model.pth", map_location="cpu")
    model.load_state_dict(state_dict)
    # Evaluation mode: any dropout / batch-norm layers use inference behaviour.
    model.eval()

    def classify(img):
        # With live=True, Gradio may call this with None (e.g. when the image
        # is cleared); return an empty label instead of crashing in
        # preprocess_image.
        if img is None:
            return None
        return predict_image(model, preprocess_image(img))

    iface = gr.Interface(
        fn=classify,
        inputs=gr.Image(),
        outputs="label",
        live=True,
    )
    iface.launch()
# Start the web app only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()