nickmuchi commited on
Commit
a561d8f
1 Parent(s): f7f1e12

Create new file

Browse files
Files changed (1) hide show
  1. app.py +140 -0
app.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import gradio as gr
3
+ import matplotlib.pyplot as plt
4
+ import requests, validators
5
+ import torch
6
+ import pathlib
7
+ from PIL import Image
8
+ import datasets
9
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
10
+ import os
11
+
12
+
13
+ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
14
+
15
+ feature_extractor = AutoFeatureExtractor.from_pretrained("/content/drive/MyDrive/Week 1 Project/saved_model_files")
16
+ model = AutoModelForImageClassification.from_pretrained("/content/drive/MyDrive/Week 1 Project/saved_model_files")
17
+
18
+ labels = ['angular_leaf_spot', 'bean_rust', 'healthy']
19
+
20
+ def classify(im):
21
+ '''FUnction for classifying plant health status'''
22
+
23
+ features = feature_extractor(im, return_tensors='pt')
24
+ with torch.no_grad():
25
+ logits = model(**features).logits
26
+ probability = torch.nn.functional.softmax(logits, dim=-1)
27
+ probs = probability[0].detach().numpy()
28
+ confidences = {label: float(probs[i]) for i, label in enumerate(labels)}
29
+
30
+ return confidences
31
+
32
+ def get_original_image(url_input):
33
+ '''Get image from URL'''
34
+ if validators.url(url_input):
35
+
36
+ image = Image.open(requests.get(url_input, stream=True).raw)
37
+
38
+ return image
39
+
40
+ def detect_plant_health(url_input,image_input,webcam_input):
41
+
42
+ if validators.url(url_input):
43
+ image = Image.open(requests.get(url_input, stream=True).raw)
44
+
45
+ elif image_input:
46
+ image = image_input
47
+
48
+ elif webcam_input:
49
+ image = webcam_input
50
+
51
+ #Make prediction
52
+ label_probs = classify(image)
53
+
54
+ return label_probs
55
+
56
+ def set_example_image(example: list) -> dict:
57
+ return gr.Image.update(value=example[0])
58
+
59
+ def set_example_url(example: list) -> dict:
60
+ return gr.Textbox.update(value=example[0]), gr.Image.update(value=get_original_image(example[0]))
61
+
62
+
63
+ title = """<h1 id="title">Plant Health Classification with ViT</h1>"""
64
+
65
+ description = """
66
+ This Plant Health classifier app was built to detect the health of plants using images of leaves by fine-tuning a Vision Transformer (ViT) [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224) on the [Beans](https://huggingface.co/datasets/beans) dataset.
67
+ The finetuned model has an accuracy of 98.4% on the test (unseen) dataset and 100% on the validation dataset.
68
+
69
+ How to use the app:
70
+ - Upload an image via 3 options, uploading the image from local device, using a URL (image from the web) or a webcam
71
+ - The app will take a few seconds to generate a prediction with the following labels:
72
+ - *'angular_leaf_spot'*
73
+ - *'bean_rust'*
74
+ - *'healthy'*
75
+ - Feel free to click the image examples as well.
76
+ """
77
+ urls = ["https://www.healthbenefitstimes.com/green-beans/","https://huggingface.co/nateraw/vit-base-beans/resolve/main/angular_leaf_spot.jpeg", "https://huggingface.co/nateraw/vit-base-beans/resolve/main/bean_rust.jpeg"]
78
+ images = [[path.as_posix()] for path in sorted(pathlib.Path('images').rglob('*.j*g'))]
79
+
80
+ twitter_link = """
81
+ [![](https://img.shields.io/twitter/follow/nickmuchi?label=@nickmuchi&style=social)](https://twitter.com/nickmuchi)
82
+ """
83
+
84
+ css = '''
85
+ h1#title {
86
+ text-align: center;
87
+ }
88
+ '''
89
+ demo = gr.Blocks(css=css)
90
+
91
+ with demo:
92
+ gr.Markdown(title)
93
+ gr.Markdown(description)
94
+ gr.Markdown(twitter_link)
95
+
96
+ with gr.Tabs():
97
+ with gr.TabItem('Image Upload'):
98
+ with gr.Row():
99
+ with gr.Column():
100
+ img_input = gr.Image(type='pil',shape=(750,750))
101
+ label_from_upload= gr.Label()
102
+
103
+ with gr.Row():
104
+ example_images = gr.Examples(examples=images,inputs=[img_input])
105
+
106
+
107
+ img_but = gr.Button('Classify')
108
+
109
+ with gr.TabItem('Image URL'):
110
+ with gr.Row():
111
+ with gr.Column():
112
+ url_input = gr.Textbox(lines=2,label='Enter valid image URL here..')
113
+ original_image = gr.Image(shape=(750,750))
114
+ url_input.change(get_original_image, url_input, original_image)
115
+ with gr.Column():
116
+ label_from_url = gr.Label()
117
+
118
+ with gr.Row():
119
+ example_url = gr.Examples(examples=urls,inputs=[url_input])
120
+
121
+
122
+ url_but = gr.Button('Classify')
123
+
124
+ with gr.TabItem('WebCam'):
125
+ with gr.Row():
126
+ with gr.Column():
127
+ web_input = gr.Image(source='webcam',type='pil',shape=(750,750),streaming=True)
128
+ with gr.Column():
129
+ label_from_webcam= gr.Label()
130
+
131
+ cam_but = gr.Button('Classify')
132
+
133
+ url_but.click(detect_plant_health,inputs=[url_input,img_input,web_input],outputs=[label_from_url],queue=True)
134
+ img_but.click(detect_plant_health,inputs=[url_input,img_input,web_input],outputs=[label_from_upload],queue=True)
135
+ cam_but.click(detect_plant_health,inputs=[url_input,img_input,web_input],outputs=[label_from_webcam],queue=True)
136
+
137
+ gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-plant-health)")
138
+
139
+
140
+ demo.launch(debug=True,enable_queue=True)