|
import torch |
|
import torchvision.transforms as transforms |
|
import gradio as gr |
|
from PIL import Image |
|
from model import SimpleCNN |
|
|
|
# Built once at import time rather than on every call; the pipeline is constant.
_PREPROCESS_TRANSFORM = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    # NOTE(review): these are the ImageNet channel statistics, not CIFAR-10's.
    # They must match whatever normalization SimpleCNN was trained with --
    # confirm against the training script before changing them.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(image):
    """Convert an input image into a normalized single-image batch for the model.

    Args:
        image: The image to classify, either a NumPy array (what ``gr.Image``
            delivers with ``type="numpy"``) or a ``PIL.Image.Image``.

    Returns:
        A float tensor of shape (1, 3, 32, 32): the image converted to RGB,
        resized to 32x32, scaled to [0, 1] by ``ToTensor`` and normalized per
        channel.

    Raises:
        ValueError: If ``image`` is None (e.g. the Gradio input is empty).
    """
    if image is None:
        raise ValueError("preprocess_image() received no image")

    # PIL images are used as-is; anything else is treated as an array.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Grayscale ('L') or RGBA uploads would give 1 or 4 channels, but the
    # Normalize step uses 3-channel statistics, so force RGB first.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    tensor = _PREPROCESS_TRANSFORM(image)
    # Add the leading batch dimension: (3, 32, 32) -> (1, 3, 32, 32).
    return tensor.unsqueeze(0)
|
|
|
def predict_image(model, image):
    """Run ``model`` on one preprocessed image and return the predicted class index.

    Args:
        model: A callable mapping ``image`` to a tensor of class scores of
            shape (1, num_classes).
        image: A batch holding exactly one preprocessed image.

    Returns:
        The index (a Python ``int``) of the highest-scoring class.

    Raises:
        RuntimeError: From ``.item()`` if the batch holds more than one image.
    """
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        scores = model(image)

    # Pick the best class along dim 1 (the class axis). argmax replaces the
    # legacy `torch.max(output.data, 1)[1]` form and gives the same index.
    return scores.argmax(dim=1).item()
|
|
|
# Human-readable names for the model's output indices.
# NOTE(review): assumes SimpleCNN was trained on torchvision's CIFAR10
# dataset, whose class order is the one below -- confirm against training.
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def main():
    """Load the trained CIFAR-10 classifier and serve it through a live Gradio UI.

    Reads the weights from ``cifar10_model.pth`` (resolved against the
    current working directory) and launches a Gradio interface that
    classifies each image the user provides.
    """
    model = SimpleCNN()
    # map_location="cpu" lets a checkpoint saved from GPU tensors be loaded
    # on a CPU-only machine.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load('cifar10_model.pth', map_location="cpu")
    model.load_state_dict(state_dict)
    # Evaluation mode: turns off training-only behaviour (e.g. dropout),
    # if SimpleCNN has any.
    model.eval()

    def classify(img):
        """Gradio callback: map an input image to a class name, or None if empty."""
        # Gradio passes None when the image input is empty, e.g. after the
        # user clears it, which the live interface also reacts to.
        if img is None:
            return None
        index = predict_image(model, preprocess_image(img))
        # Fall back to the raw index if the model emits more classes than named.
        if 0 <= index < len(CIFAR10_CLASSES):
            return CIFAR10_CLASSES[index]
        return str(index)

    iface = gr.Interface(
        fn=classify,
        # preprocess_image converts its input with Image.fromarray, so request
        # NumPy arrays explicitly instead of relying on the component default.
        inputs=gr.Image(type="numpy"),
        outputs="label",
        live=True,
    )

    iface.launch()


if __name__ == "__main__":
    main()