jeffhaines commited on
Commit
67a135c
1 Parent(s): b4e4366

Create new file

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ from transformers import pipeline, ConvNextForImageClassification, ConvNextFeatureExtractor, ViTForImageClassification, ViTFeatureExtractor, AutoFeatureExtractor, ResNetForImageClassification
5
+ from PIL import Image
6
+
7
+ #load the models
8
+ convnext_model = ConvNextForImageClassification.from_pretrained('convnext-rice')
9
+ convnext_feature_extractor = ConvNextFeatureExtractor.from_pretrained('facebook/convnext-tiny-224')
10
+ convnext_clf = pipeline("image-classification", model = convnext_model, feature_extractor = convnext_feature_extractor)
11
+
12
+ vit_model = ViTForImageClassification.from_pretrained('vit-rice')
13
+ vit_feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
14
+ vit_clf = pipeline("image-classification", model = vit_model, feature_extractor = vit_feature_extractor)
15
+
16
+ resnet_model = ResNetForImageClassification.from_pretrained('resnet-rice')
17
+ resnet_feature_extractor = AutoFeatureExtractor.from_pretrained('microsoft/resnet-50')
18
+ resnet_clf = pipeline("image-classification", model = resnet_model, feature_extractor = resnet_feature_extractor)
19
+
20
+ #define the functions
21
+ def convnext_classify(image):
22
+ convnext_scores = convnext_clf(image)
23
+ return convnext_scores
24
+
25
+ def resnet_classify(image):
26
+ resnet_scores = resnet_clf(image)
27
+ return resnet_scores
28
+
29
+ def vit_classify(image):
30
+ vit_scores = vit_clf(image)
31
+ return vit_scores
32
+
33
+ def classify(image):
34
+ convnext_scores = convnext_classify(image)
35
+ resnet_scores = resnet_classify(image)
36
+ vit_scores = vit_classify(image)
37
+
38
+ score_dict = {}
39
+ results = [resnet_scores, vit_scores, convnext_scores]
40
+ for result in results:
41
+ for item in result:
42
+ item['label'] = item['label'].replace('_', ' ')
43
+ if item['label'] not in score_dict:
44
+ score_dict[item['label']] = 0
45
+ score_dict[item['label']] += item['score'] / 3
46
+
47
+ return score_dict
48
+
49
+ with gr.Blocks() as demo:
50
+ gr.Markdown('# Rice Disease Classifier')
51
+ gr.Markdown('Rice is one of the most popular crops in the world, and is an especially important staple in developing countries. Farmers may find it useful to quickly classify what disease or pest is affecting their crops. This app allows for a picture of a rice plant to be quickly uploaded and classified. The app uses an ensemble of three pre-trained models - a Google Vision Transformer, a ConvNeXT model, and Microsoft\'s Resnet 50. These models were then fine-tuned on images of healthy and diseased rice found at https://www.kaggle.com/competitions/paddy-disease-classification.')
52
+ gr.Markdown('Please note that for best results images should show detail of the plant. Images of large fields are unlikely to show enough detail for the model to identify a disease.')
53
+
54
+ inputs=gr.Image(type="pil")
55
+ outputs=gr.Label()
56
+
57
+ image_button = gr.Button("Classify")
58
+
59
+ image_button.click(classify, inputs=inputs, outputs=outputs),
60
+
61
+ gr.Markdown("## Image Examples")
62
+ with gr.Row():
63
+ gr.Examples(
64
+ examples=['rice-blast.jfif','rice-deadheart.jfif','rice-healthy.jfif','rice-hispa.jpg'], inputs = inputs)
65
+
66
+ demo.launch()