"""Gradio demo that labels an uploaded image with a pretrained ViT classifier."""

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import ViTForImageClassification

# Load the classifier once at import time and put it in inference mode
# (eval() disables training-only behaviour such as dropout).
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
model.eval()

# Preprocessing: resize to 224x224 (the resolution named by the checkpoint),
# convert to a float CHW tensor in [0, 1], then shift/scale each channel with
# mean=0.5, std=0.5, which maps values into [-1, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


def predict_image(img):
    """Return the model's top-1 label for *img*.

    Args:
        img: The image from the Gradio input, normally a ``PIL.Image.Image``.
            An array-like image is also accepted. ``None`` is passed when
            the input is cleared.

    Returns:
        The human-readable label of the highest-scoring class, or ``""``
        when no image is provided.
    """
    if img is None:
        # live=True re-runs the function when the input is cleared.
        return ""
    if not isinstance(img, Image.Image):
        # torchvision's Resize accepts only PIL images or tensors, not ndarrays.
        img = Image.fromarray(img)
    # PNG uploads may be RGBA/palette and some images are grayscale; the
    # 3-channel Normalize above requires exactly three channels.
    tensor_img = transform(img.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        outputs = model(tensor_img)
    predictions = outputs.logits.argmax(-1)
    return model.config.id2label[predictions.item()]


# Build the web interface. type="pil" makes Gradio hand predict_image a PIL
# image; its default ("numpy") would deliver an ndarray instead. The removed
# `shape` / `capture_session` arguments are dropped because the transform
# already resizes the image.
iface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),
    outputs="text",
    live=True,
    title="Image recognition",
    description="Upload an image you want to categorize.",
    theme="Monochrome",
)
iface.launch()