"""Gradio demo that classifies an uploaded image with an ImageNet ResNet-50."""

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models

# ``pretrained=True`` is deprecated in torchvision >= 0.13; IMAGENET1K_V1 is
# the checkpoint that flag used to load, so predictions are unchanged. The
# weights object also carries the ImageNet label list, which the script
# previously referenced as ``class_names`` without ever defining it.
WEIGHTS = models.ResNet50_Weights.IMAGENET1K_V1
model = models.resnet50(weights=WEIGHTS)
model.eval()  # inference mode: freeze batch-norm statistics, disable dropout

# Human-readable class labels, indexed by the model's output position.
class_names = WEIGHTS.meta["categories"]

# Standard ImageNet preprocessing (matches the IMAGENET1K_V1 training recipe).
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


def classify_image(input_image):
    """Return the predicted ImageNet class name for *input_image*.

    Args:
        input_image: A PIL image (what the Gradio ``Image`` component below
            supplies with ``type="pil"``), or a path / file object accepted
            by ``PIL.Image.open``. ``None`` means no image is present.

    Returns:
        The label of the highest-scoring ImageNet class, or an empty string
        when no image was provided.
    """
    if input_image is None:
        # Gradio passes None for an empty image input (e.g. after clearing);
        # with live=True this function may be invoked in that state.
        return ""

    if isinstance(input_image, Image.Image):
        img = input_image
    else:
        img = Image.open(input_image)

    # Normalize and the ResNet stem expect exactly 3 channels; RGBA PNGs,
    # grayscale and palette images would otherwise raise a shape error.
    batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dim

    with torch.no_grad():  # no autograd bookkeeping needed for inference
        logits = model(batch)

    predicted_index = logits.argmax(dim=1).item()
    return class_names[predicted_index]


# ``gr.inputs`` / ``gr.outputs`` were removed in Gradio 4; use the top-level
# components. ``type="file"`` is not a valid Image type, so request a PIL
# image directly.
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=gr.Textbox(label="Predicted Class"),
    live=True,
    theme="default",
    title="Image Classification with ResNet",
)

if __name__ == "__main__":
    # Only start the web server when run as a script, not on import.
    iface.launch()