Spaces:
Sleeping
Sleeping
import gradio as gr | |
from PIL import Image | |
import torch | |
from torchvision import transforms | |
from model import AlexNet # Assuming you have defined AlexNet in model.py | |
# Load the pre-trained AlexNet (class defined in model.py) and restore its
# weights from a local checkpoint. The path is relative to the current working
# directory; map_location forces every tensor onto the CPU so the app also runs
# on hosts without a GPU.
# weights_only=True restricts torch.load's unpickler to tensors and primitive
# containers, which is all a state_dict needs, so a tampered checkpoint file
# cannot execute arbitrary code while loading.
model = AlexNet()
model.load_state_dict(
    torch.load(
        "./alexnet_model_v1.pth",
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)
# Put the network in evaluation mode (disables training-only behaviour such as
# dropout, if the architecture uses it) before serving predictions.
model.eval()
# Image preprocessing applied to every input before it reaches the model.
# NOTE(review): these steps must match the preprocessing used when
# alexnet_model_v1.pth was trained — confirm against the training script.
preprocess = transforms.Compose(
    [
        # Scale so the shorter side is 256 px (aspect ratio preserved) ...
        transforms.Resize(256),
        # ... then keep the central 224x224 patch.
        transforms.CenterCrop(224),
        # PIL image -> float tensor laid out as (C, H, W) with values in [0, 1].
        transforms.ToTensor(),
        # Per-channel standardisation with the usual ImageNet mean/std values.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Human-readable class names. Position i names the model's i-th output logit
# (predict_image indexes this list with output indices). These are the ten
# CIFAR-10 categories in their conventional order.
labels = [
    'airplane',    # 0
    'automobile',  # 1
    'bird',        # 2
    'cat',         # 3
    'deer',        # 4
    'dog',         # 5
    'frog',        # 6
    'horse',       # 7
    'ship',        # 8
    'truck',       # 9
]
# Inference entry point wired to the Gradio interface.
def predict_image(image):
    """Classify one uploaded image with the module-level AlexNet model.

    Args:
        image: The value produced by the Gradio ``Image`` input: an array of
            pixel values (Gradio's default ``type="numpy"``), or ``None`` when
            the user submits without uploading anything.

    Returns:
        dict mapping class name -> softmax probability (Python float), ordered
        from most to least likely. Empty dict when no image was supplied.
    """
    # A missing upload arrives as None; report "no prediction" instead of
    # crashing on None.astype(...).
    if image is None:
        return {}

    # Let PIL infer the mode from the array shape, then normalise to 3-channel
    # RGB so grayscale (H, W) or RGBA (H, W, 4) arrays are handled too. This also
    # avoids the mode argument of Image.fromarray, which recent Pillow deprecates.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    batch = preprocess(pil_image).unsqueeze(0)  # add a leading batch dimension

    # Forward pass without autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.softmax(logits, dim=1)[0]

    # Rank classes by confidence. k is derived from `labels` rather than
    # hard-coded, and never exceeds the number of model outputs.
    # NOTE(review): assumes the model emits exactly one logit per entry in
    # `labels` — confirm against model.py.
    top_k = min(len(labels), probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k=top_k)
    return {
        labels[idx]: prob
        for idx, prob in zip(top_indices.tolist(), top_probs.tolist())
    }
# Gradio UI: a single image input feeding predict_image, rendered as a label.
# type="numpy" is Gradio's default, but predict_image calls .astype() on the
# value it receives, so the contract is spelled out here.
inputs = gr.components.Image(type="numpy")
# Show every class. The count comes from `labels` so the UI cannot drift out of
# sync if the label list changes.
outputs = gr.components.Label(num_top_classes=len(labels))
interface = gr.Interface(
    fn=predict_image,
    inputs=inputs,
    outputs=outputs,
    title="AlexNet Image Classifier",
)
# Start the Gradio web server.
interface.launch()