Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| from torchvision import transforms | |
| from model import AlexNet # Assuming you have defined AlexNet in model.py | |
# Load the pre-trained AlexNet weights onto the CPU and put the network into
# evaluation mode, so layers such as dropout/batch-norm (if the model defines
# any) use their inference behaviour.
model = AlexNet()
# The checkpoint is a state_dict (it is fed to load_state_dict), so
# weights_only=True suffices: it restricts unpickling to tensors and plain
# containers instead of allowing arbitrary objects to be reconstructed.
# NOTE(review): requires torch >= 1.13; it is already the default from 2.6 on.
model.load_state_dict(
    torch.load(
        "./alexnet_model_v1.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
model.eval()
# Per-channel normalization statistics (the commonly used ImageNet values).
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Input pipeline applied to every PIL image before inference:
#   1. scale so the shorter side is 256 px,
#   2. cut out the central 224x224 patch,
#   3. convert to a float tensor,
#   4. normalize each channel with the statistics above.
_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
preprocess = transforms.Compose(_steps)
# Human-readable class names; position i names the model's i-th output
# (same names and order as the CIFAR-10 class list).
labels = "airplane automobile bird cat deer dog frog horse ship truck".split()
def predict_image(image):
    """Classify one image and return per-class confidences for ``gr.Label``.

    Args:
        image: The uploaded picture. With the default ``gr.Image()`` input,
            Gradio supplies a NumPy array (anything exposing ``.astype`` works).
            A ``PIL.Image.Image`` is also accepted. ``None`` (nothing uploaded)
            is tolerated.

    Returns:
        dict mapping class name to probability (Python float) for the up-to-10
        most likely classes, ordered from most to least likely. Returns an
        empty dict when no image was supplied.
    """
    if image is None:
        # Submitting an empty input previously raised AttributeError on .astype.
        return {}

    if isinstance(image, Image.Image):
        pil_image = image.convert("RGB")
    else:
        # Let PIL infer the mode from the array shape (L, RGB, RGBA, ...) and
        # convert afterwards. Forcing mode 'RGB' misreads 2-D grayscale or
        # 4-channel RGBA arrays.
        pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")

    batch = preprocess(pil_image).unsqueeze(0)  # add a batch dimension of size 1

    with torch.no_grad():
        logits = model(batch)

    probabilities = torch.softmax(logits, dim=1)[0]
    # topk returns results sorted in descending order, which reproduces the
    # old argsort-then-slice ordering without sorting every class.
    k = min(10, probabilities.numel())
    top_probs, top_indices = torch.topk(probabilities, k)
    return {
        labels[idx]: prob
        for idx, prob in zip(top_indices.tolist(), top_probs.tolist())
    }
# Web UI: an image upload component feeding predict_image, whose result dict
# is shown as a label widget listing up to 10 classes.
inputs = gr.Image(type="numpy")  # numpy is Gradio's default; spelled out because predict_image calls .astype
outputs = gr.Label(num_top_classes=10)

interface = gr.Interface(
    fn=predict_image,
    inputs=inputs,
    outputs=outputs,
    title="AlexNet Image Classifier",
)

# Start the Gradio server for this app.
interface.launch()