"""Gradio demo that classifies an image as Cat, Dog, Horse or Monkey.

Loads a pickled PyTorch classifier from ``classifier.pt`` and serves it
through a live Gradio interface that displays per-class probabilities.
"""
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr

# Class names, in the order of the model's output logits.
CLASS_LABELS = ["Cat", "Dog", "Horse", "Monkey"]

# Run on the GPU when available; inputs are moved to the same device as the model.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the trained model.
# - weights_only=False is needed to unpickle a full nn.Module on torch >= 2.6,
#   where the default became True. Unpickling can execute arbitrary code, so
#   only load checkpoints from a trusted source.
# - map_location places the weights on DEVICE, so a checkpoint saved on a GPU
#   still loads on a CPU-only machine and matches the input tensor's device.
model = torch.load("classifier.pt", map_location=DEVICE, weights_only=False)
# Switch to inference behaviour: disables dropout, batch-norm uses running stats.
model.eval()

# Preprocessing pipeline, built once instead of on every request.
# NOTE(review): mean/std look like the usual ImageNet statistics; confirm they
# match what the model was trained with.
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess(image):
    """Turn an image array into a normalized ``1 x 3 x 224 x 224`` batch tensor.

    Args:
        image: array-like image with an ``astype`` method (as passed by
            ``gr.Image``); grayscale, RGB and RGBA layouts are accepted.

    Returns:
        A float tensor with a leading batch dimension of 1, on the CPU.
    """
    # Let PIL infer the mode from the array shape, then normalize to 3-channel
    # RGB; forcing mode 'RGB' would misread grayscale or RGBA buffers.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    return _TRANSFORM(pil_image).unsqueeze(0)


def predict(image):
    """Classify one image and return a ``{label: probability}`` mapping.

    Args:
        image: image array from the Gradio input, or ``None`` when the input
            is empty (e.g. after it has been cleared in the live interface).

    Returns:
        Dict mapping each label in ``CLASS_LABELS`` to its softmax probability;
        an empty dict when no image is provided.
    """
    if image is None:
        return {}

    input_tensor = preprocess(image).to(DEVICE)
    with torch.no_grad():
        output = model(input_tensor)

    # Hugging Face-style models return an object carrying ``.logits``; plain
    # nn.Modules return the logits tensor directly.
    logits = getattr(output, "logits", output)
    # Take batch item 0 instead of squeeze() so a single-class output still
    # yields a list rather than a bare float.
    probabilities = torch.softmax(logits, dim=1)[0].tolist()
    return dict(zip(CLASS_LABELS, probabilities))


# Gradio interface; live=True re-runs predict whenever the input changes.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=gr.Label(num_top_classes=len(CLASS_LABELS)),
    live=True,
)

if __name__ == "__main__":
    iface.launch(quiet=True)