import gradio as gr
import pandas as pd
import numpy as np
from transformers import (
    pipeline,
    ConvNextForImageClassification,
    ConvNextFeatureExtractor,
    ViTForImageClassification,
    ViTFeatureExtractor,
    AutoFeatureExtractor,
    ResNetForImageClassification,
)
from PIL import Image

# Load the models. Fine-tuned weights are read from the local directories
# 'convnext', 'vit' and 'resnet'; preprocessing configs are taken from the
# corresponding base checkpoints.
# NOTE(review): ConvNextFeatureExtractor / ViTFeatureExtractor and the
# `feature_extractor=` pipeline argument are deprecated in newer transformers
# releases in favour of image processors — confirm against the pinned
# transformers version before upgrading.
convnext_model = ConvNextForImageClassification.from_pretrained('convnext')
convnext_feature_extractor = ConvNextFeatureExtractor.from_pretrained('facebook/convnext-tiny-224')
convnext_clf = pipeline("image-classification", model=convnext_model, feature_extractor=convnext_feature_extractor)

vit_model = ViTForImageClassification.from_pretrained('vit')
vit_feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
vit_clf = pipeline("image-classification", model=vit_model, feature_extractor=vit_feature_extractor)

resnet_model = ResNetForImageClassification.from_pretrained('resnet')
resnet_feature_extractor = AutoFeatureExtractor.from_pretrained('microsoft/resnet-50')
resnet_clf = pipeline("image-classification", model=resnet_model, feature_extractor=resnet_feature_extractor)


# Define the per-model classification functions.
def convnext_classify(image):
    """Return the ConvNeXt pipeline's predictions for *image*."""
    return convnext_clf(image)


def resnet_classify(image):
    """Return the ResNet pipeline's predictions for *image*."""
    return resnet_clf(image)


def vit_classify(image):
    """Return the ViT pipeline's predictions for *image*."""
    return vit_clf(image)


def classify(image):
    """Classify *image* with all three models and average their scores.

    Each pipeline returns a list of ``{'label': ..., 'score': ...}`` dicts.
    Scores for the same label are summed across models and divided by the
    number of models, so a label absent from one model's predictions
    contributes 0 for that model.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when Classify is clicked without an uploaded image.

    Returns:
        Dict mapping label (underscores replaced by spaces) to averaged
        score, the format consumed by ``gr.Label``.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Without this guard the pipelines fail on None with an opaque error.
        raise gr.Error('Please upload an image of a rice plant first.')

    convnext_scores = convnext_classify(image)
    resnet_scores = resnet_classify(image)
    vit_scores = vit_classify(image)
    results = [resnet_scores, vit_scores, convnext_scores]

    # Divide by the actual number of models rather than a hard-coded 3 so the
    # average stays correct if models are added or removed.
    n_models = len(results)
    score_dict = {}
    for result in results:
        for item in result:
            # Compute the display label locally instead of mutating the
            # pipeline's output dicts in place.
            label = item['label'].replace('_', ' ')
            score_dict[label] = score_dict.get(label, 0) + item['score'] / n_models
    return score_dict


with gr.Blocks() as demo:
    gr.Markdown('# Rice Disease Classifier')
    gr.Markdown(
        'Rice is one of the most popular crops in the world, and is an '
        'especially important staple in developing countries. Farmers may find '
        'it useful to quickly classify what disease or pest is affecting their '
        'crops. This app allows for a picture of a rice plant to be quickly '
        'uploaded and classified. The app uses an ensemble of three pre-trained '
        "models - a Google Vision Transformer, a ConvNeXT model, and Microsoft's "
        'Resnet 50. These models were then fine-tuned on images of healthy and '
        'diseased rice found at '
        'https://www.kaggle.com/competitions/paddy-disease-classification.'
    )
    gr.Markdown(
        'Please note that for best results images should show detail of the '
        'plant. Images of large fields are unlikely to show enough detail for '
        'the model to identify a disease.'
    )
    inputs = gr.Image(type="pil")
    outputs = gr.Label()
    image_button = gr.Button("Classify")
    image_button.click(classify, inputs=inputs, outputs=outputs)

    gr.Markdown("## Image Examples")
    with gr.Row():
        gr.Examples(
            examples=['rice-blast.jfif', 'rice-deadheart.jfif', 'rice-healthy.jfif', 'rice-hispa.jpg'],
            inputs=inputs,
        )

demo.launch()