|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import gradio as gr |
|
|
|
|
|
# Load the classifier once at import time. It is presumably a whole pickled
# ``nn.Module`` (``predict`` calls it directly) — TODO confirm.
# ``torch.load`` unpickles the file, so 'classifier.pt' must be trusted.
# NOTE(review): PyTorch >= 2.6 defaults to ``weights_only=True``, which
# refuses whole pickled modules; if loading fails there, pass
# ``weights_only=False`` (trusted files only) or switch to a state_dict.
# ``map_location="cpu"`` keeps GPU-saved weights loadable on CPU-only hosts
# and matches ``preprocess``, whose tensors are created on the CPU.
model = torch.load('classifier.pt', map_location="cpu")
# Put layers such as dropout / batch norm into inference behaviour; a module
# saved during training may otherwise still be in train mode.
model.eval()
|
|
|
|
|
# Preprocessing pipeline, built once rather than on every call, since the
# live interface may call ``predict`` (and hence ``preprocess``) repeatedly.
# The mean/std values match the commonly used ImageNet channel statistics;
# presumably the classifier was trained with the same normalization — TODO
# confirm.
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess(image):
    """Convert an image array into a normalized single-image batch tensor.

    Args:
        image: Pixel data as a NumPy array (assumed ``(H, W)``, ``(H, W, 3)``
            or ``(H, W, 4)`` with 0-255 values — TODO confirm against the
            Gradio version in use).

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``: the image converted to
        RGB, resized to 224x224, normalized, with a leading batch dimension.
    """
    # Let Pillow infer the mode from the array and then convert to RGB:
    # forcing mode 'RGB' in ``fromarray`` misreads 2-D grayscale or
    # 4-channel RGBA buffers instead of converting them.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    return _TRANSFORM(pil_image).unsqueeze(0)
|
|
|
|
|
# Display names for the model's output units, in output-index order.
# These look like placeholders — presumably to be replaced with real names.
CLASS_LABELS = ["Class1", "Class2", "Class3", "Class4"]


def predict(image):
    """Classify one image and return per-class probabilities.

    Args:
        image: Pixel array from the image input, or ``None`` when there is no
            image (presumably what the input sends when it is cleared).

    Returns:
        ``None`` when ``image`` is ``None``; otherwise a dict mapping each name
        in ``CLASS_LABELS`` to its softmax probability (a Python float).

    Raises:
        ValueError: If the model's output width differs from the number of
            labels, rather than letting ``zip`` silently drop entries.
    """
    if image is None:
        return None

    input_tensor = preprocess(image)
    with torch.no_grad():
        output = model(input_tensor)

    # Take the single batch row explicitly; ``squeeze()`` would also collapse
    # a size-1 class dimension and return a bare float instead of a list.
    probabilities = torch.softmax(output, dim=1)[0].tolist()
    if len(probabilities) != len(CLASS_LABELS):
        raise ValueError(
            f"model produced {len(probabilities)} class scores but "
            f"{len(CLASS_LABELS)} labels are defined"
        )
    return dict(zip(CLASS_LABELS, probabilities))
|
|
|
|
|
# Gradio UI: an image input wired to ``predict``, shown as a ranked label.
# ``live=True`` asks Gradio to re-run ``predict`` as the input changes.
iface = gr.Interface(
    fn=predict,
    # ``type="numpy"`` made explicit: ``preprocess`` calls ``.astype`` on the
    # input and therefore requires a NumPy array.
    inputs=gr.Image(type="numpy"),
    outputs=gr.Label(num_top_classes=4),
    live=True,
)

# Start the server only when run as a script, so importing this module does
# not launch a web server as a side effect.
if __name__ == "__main__":
    iface.launch()
|
|