# 11Abdul's picture
# New update
# 85d930d
import gradio as gr
from PIL import Image
import torch
from torchvision import transforms
from model import AlexNet # Assuming you have defined AlexNet in model.py
# Build the network and restore its trained weights from the checkpoint file.
# The relative path is resolved against the current working directory.
# map_location="cpu" puts every loaded tensor on the CPU, so the app works on
# machines without a GPU even if the checkpoint was saved from one.
# NOTE(review): torch.load unpickles the file; that is fine for a checkpoint
# shipped with the app, but consider weights_only=True (torch >= 1.13) if the
# file could ever come from an untrusted source.
model = AlexNet()
model.load_state_dict(
    torch.load(
        "./alexnet_model_v1.pth",
        map_location="cpu",
    )
)
# Put the network in evaluation mode so layers like Dropout/BatchNorm (if the
# architecture in model.py has any) use their inference-time behaviour.
model.eval()
# Image pipeline applied to every upload before inference:
#   1. resize so the shorter side is 256 px (aspect ratio kept),
#   2. cut out the central 224x224 patch,
#   3. convert the PIL image to a float tensor with values in [0, 1],
#   4. normalise each channel with the mean/std below (the standard
#      ImageNet channel statistics).
# NOTE(review): the class names below are CIFAR-10 categories; confirm this
# pipeline matches the preprocessing used when the checkpoint was trained.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Human-readable class names; position i names the model's i-th output logit.
labels = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
# Define function to perform inference on an image
def predict_image(image):
    """Classify one uploaded image and return a class -> probability mapping.

    Args:
        image: The picture delivered by the Gradio ``Image`` input as a NumPy
            array (grayscale, RGB or RGBA), or ``None`` when the form is
            submitted without an image.

    Returns:
        A dict mapping every entry of ``labels`` to its softmax probability,
        ordered from most to least likely (the format ``gr.Label`` expects),
        or ``None`` when no image was supplied, which leaves the output empty.

    Raises:
        ValueError: If the model's number of outputs differs from the number
            of entries in ``labels``.
    """
    # Gradio passes None for an empty submission; bail out instead of raising
    # AttributeError on ``None.astype``.
    if image is None:
        return None

    # Let Pillow infer the mode from the array shape, then normalise to RGB.
    # Forcing mode='RGB' in fromarray misreads (H, W) grayscale and
    # (H, W, 4) RGBA buffers.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    batch = preprocess(pil_image).unsqueeze(0)  # prepend a batch dim of size 1

    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.softmax(logits, dim=1)[0].tolist()

    # Pairing with ``labels`` silently truncates on a mismatch, so surface it.
    if len(probabilities) != len(labels):
        raise ValueError(
            f"model produced {len(probabilities)} class scores but "
            f"{len(labels)} labels are defined"
        )

    # Rank every class by probability, most likely first.
    ranked = sorted(zip(labels, probabilities), key=lambda pair: pair[1], reverse=True)
    return dict(ranked)
# Create Gradio interface: an image upload in, a ranked label list out.
# type="numpy" is explicit because predict_image calls ``.astype`` on its
# input, which requires a NumPy array.
inputs = gr.Image(type="numpy")
# Show every class; derived from ``labels`` so the count cannot drift from it.
outputs = gr.Label(num_top_classes=len(labels))
interface = gr.Interface(
    fn=predict_image,
    inputs=inputs,
    outputs=outputs,
    title="AlexNet Image Classifier",
)

# Launch the Gradio interface only when executed as a script
# (e.g. ``python app.py``), so merely importing this module does not start
# a web server.
if __name__ == "__main__":
    interface.launch()