"""Gradio demo that classifies crop-leaf images with a fine-tuned ViT model."""

import gradio as gr
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification
import torch

MODEL_ID = "wambugu1738/crop_leaf_diseases_vit"

# Load the image processor and model once at import time so every request
# reuses them instead of reloading per call.
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
model = ViTForImageClassification.from_pretrained(
    MODEL_ID,
    # NOTE(review): ignore_mismatched_sizes=True makes loading succeed even when
    # checkpoint weights have a different shape than the model; those weights
    # are freshly initialised instead, which would make predictions meaningless.
    # Confirm the checkpoint loads without mismatch warnings and drop this flag
    # if it is not actually needed.
    ignore_mismatched_sizes=True,
)
# Explicitly put the model in inference mode (disables dropout etc.).
model.eval()


def classify_image(image):
    """Predict the label for a single leaf image.

    Args:
        image: The uploaded image as delivered by ``gr.Image(type="numpy")``,
            or ``None`` when the user submits without uploading anything.

    Returns:
        The label string from ``model.config.id2label`` for the highest-scoring
        class, a prompt asking for an image when ``image`` is ``None``, or a
        placeholder naming the index if the config has no label for it.
    """
    if image is None:
        return "Please upload an image."

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip building the autograd graph to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label.get(
        predicted_class_idx, f"Unknown class {predicted_class_idx}"
    )


# Create the Gradio interface. type="numpy" hands classify_image a NumPy array,
# which the image processor accepts directly.
app = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="numpy"),
    outputs="text",
)

# Launch the Gradio app with a public share link.
app.launch(share=True)