File size: 3,271 Bytes
67a135c
 
 
 
 
 
 
c344d3c
67a135c
 
 
c344d3c
67a135c
 
 
c344d3c
67a135c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
import pandas as pd
import numpy as np
from transformers import pipeline, ConvNextForImageClassification, ConvNextFeatureExtractor, ViTForImageClassification, ViTFeatureExtractor, AutoFeatureExtractor, ResNetForImageClassification
from PIL import Image

# Load the three fine-tuned classifiers. Model weights come from checkpoints
# named "convnext", "vit" and "resnet" (presumably directories shipped next to
# this app -- confirm), while preprocessing configs come from the matching
# upstream Hub checkpoints. Each model/extractor pair is wrapped in an
# "image-classification" pipeline that the *_classify functions call.
convnext_model = ConvNextForImageClassification.from_pretrained("convnext")
convnext_feature_extractor = ConvNextFeatureExtractor.from_pretrained(
    "facebook/convnext-tiny-224"
)
convnext_clf = pipeline(
    "image-classification",
    model=convnext_model,
    feature_extractor=convnext_feature_extractor,
)

vit_model = ViTForImageClassification.from_pretrained("vit")
vit_feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-224-in21k"
)
vit_clf = pipeline(
    "image-classification",
    model=vit_model,
    feature_extractor=vit_feature_extractor,
)

resnet_model = ResNetForImageClassification.from_pretrained("resnet")
resnet_feature_extractor = AutoFeatureExtractor.from_pretrained(
    "microsoft/resnet-50"
)
resnet_clf = pipeline(
    "image-classification",
    model=resnet_model,
    feature_extractor=resnet_feature_extractor,
)

#define the functions
def convnext_classify(image):
    """Return the ConvNeXt pipeline's predictions for ``image`` unchanged.

    The result is consumed by ``classify``, which reads each item's
    ``'label'`` and ``'score'``.
    """
    return convnext_clf(image)
                                   
def resnet_classify(image):
    """Return the ResNet pipeline's predictions for ``image`` unchanged.

    The result is consumed by ``classify``, which reads each item's
    ``'label'`` and ``'score'``.
    """
    return resnet_clf(image)
    
def vit_classify(image):
    """Return the ViT pipeline's predictions for ``image`` unchanged.

    The result is consumed by ``classify``, which reads each item's
    ``'label'`` and ``'score'``.
    """
    return vit_clf(image)
    
def average_label_scores(results):
    """Average per-label scores across several classifiers' predictions.

    Args:
        results: a sequence with one entry per model; each entry is an
            iterable of predictions, and each prediction is a mapping with a
            ``'label'`` (str) and a ``'score'`` (number).

    Returns:
        dict mapping each label, with underscores replaced by spaces, to its
        mean score across all models. A label absent from a model's
        predictions contributes 0 for that model (NOTE(review): the pipelines
        presumably return only their top-k labels, so rare labels are
        under-weighted -- confirm this is intended). An empty ``results``
        yields ``{}``.
    """
    if not results:
        return {}
    # Divide by the actual number of models rather than a hard-coded 3, so
    # adding or removing a model keeps this a true mean.
    n_models = len(results)
    totals = {}
    for predictions in results:
        for item in predictions:
            # Normalise the label for display without mutating the caller's
            # prediction dicts.
            label = item['label'].replace('_', ' ')
            totals[label] = totals.get(label, 0) + item['score'] / n_models
    return totals


def classify(image):
    """Classify ``image`` with the ConvNeXt / ResNet / ViT ensemble.

    Args:
        image: the image passed in from the Gradio image input.

    Returns:
        dict of label -> mean score across the three models, in the shape
        the ``gr.Label`` output displays.
    """
    convnext_scores = convnext_classify(image)
    resnet_scores = resnet_classify(image)
    vit_scores = vit_classify(image)
    return average_label_scores([resnet_scores, vit_scores, convnext_scores])
  
# Build the Gradio interface: a title and description, an image input, a
# "Classify" button wired to the ensemble classifier, a label output, and a
# row of example images.
with gr.Blocks() as demo:
    gr.Markdown('# Rice Disease Classifier')
    gr.Markdown('Rice is one of the most popular crops in the world, and is an especially important staple in developing countries. Farmers may find it useful to quickly classify what disease or pest is affecting their crops. This app allows for a picture of a rice plant to be quickly uploaded and classified. The app uses an ensemble of three pre-trained models - a Google Vision Transformer, a ConvNeXT model, and Microsoft\'s Resnet 50. These models were then fine-tuned on images of healthy and diseased rice found at https://www.kaggle.com/competitions/paddy-disease-classification.')
    gr.Markdown('Please note that for best results images should show detail of the plant. Images of large fields are unlikely to show enough detail for the model to identify a disease.')

    # type="pil" hands classify() a PIL image.
    inputs = gr.Image(type="pil")
    outputs = gr.Label()

    image_button = gr.Button("Classify")

    # Run the ensemble on the current image and show the averaged scores.
    # (The original had a stray trailing comma here that built and discarded
    # a tuple.)
    image_button.click(classify, inputs=inputs, outputs=outputs)

    gr.Markdown("## Image Examples")
    with gr.Row():
        # No fn/outputs are passed, so selecting an example presumably just
        # fills the image input; the user then presses "Classify".
        gr.Examples(
            examples=['rice-blast.jfif', 'rice-deadheart.jfif', 'rice-healthy.jfif', 'rice-hispa.jpg'],
            inputs=inputs,
        )

# Launch after the Blocks context has closed so the layout is finalised first.
demo.launch()