Spaces:
Runtime error
Runtime error
import io | |
import gradio as gr | |
import requests, validators | |
import torch | |
import pathlib | |
from PIL import Image | |
import datasets | |
from transformers import AutoFeatureExtractor, AutoModelForImageClassification | |
import os | |
import IPython | |
# Intel OpenMP setting: tolerate a duplicate OpenMP runtime being loaded
# instead of aborting the process.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Local directory holding the fine-tuned checkpoint; the preprocessor and the
# classifier are both loaded from it.
_MODEL_DIR = "saved_model_files"
feature_extractor = AutoFeatureExtractor.from_pretrained(_MODEL_DIR)
model = AutoModelForImageClassification.from_pretrained(_MODEL_DIR)

# Class names, indexed by position in the model's output vector (see classify).
labels = ['angular_leaf_spot', 'bean_rust', 'healthy']
def classify(im):
    '''Return a ``{label: probability}`` mapping for one leaf image.

    The image is preprocessed with the module-level ``feature_extractor``,
    scored by ``model`` with gradient tracking disabled, and the logits are
    softmax-normalised over the last dimension. Only the first row of the
    resulting batch is reported, one entry per name in ``labels``.
    '''
    inputs = feature_extractor(im, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)

    distribution = torch.nn.functional.softmax(outputs.logits, dim=-1)
    scores = distribution[0].detach().numpy()

    confidences = {}
    for index, name in enumerate(labels):
        confidences[name] = float(scores[index])
    return confidences
def get_original_image(url_input):
    '''Download the image at ``url_input`` and return it as a PIL image.

    This function is wired to the URL textbox's ``change`` event, so it is
    called with partial, invalid and non-image URLs while the user types.
    All of those cases yield ``None`` (nothing to preview) rather than an
    exception.

    Args:
        url_input: Text entered by the user; may be empty or ``None``.

    Returns:
        A fully loaded ``PIL.Image.Image``, or ``None`` when the text is not
        a valid URL, the request fails or times out, the server answers with
        an error status, or the body cannot be decoded as an image.
    '''
    if not url_input or not validators.url(url_input):
        return None

    try:
        # A timeout keeps an unresponsive host from blocking the worker.
        response = requests.get(url_input, timeout=10)
        response.raise_for_status()
        # Decode from the buffered body: ``response.content`` has any
        # content-encoding (gzip/deflate) already undone, unlike ``.raw``.
        image = Image.open(io.BytesIO(response.content))
        # Force decoding now so truncated or corrupt data is caught here.
        image.load()
    except (requests.RequestException, OSError):
        # OSError covers Pillow's UnidentifiedImageError for non-image bodies.
        return None
    return image
def detect_plant_health(url_input, image_input, webcam_input):
    '''Classify the first available image source.

    Sources are considered in priority order: a valid URL, then an uploaded
    image, then a webcam frame. Note that a valid URL always wins, even if
    an uploaded image or webcam frame is supplied as well.

    Args:
        url_input: Text that may contain an image URL (may be empty/None).
        image_input: Uploaded image, or ``None``.
        webcam_input: Webcam frame, or ``None``.

    Returns:
        The ``{label: probability}`` mapping produced by ``classify``.

    Raises:
        ValueError: If no source yields an image (nothing supplied, or the
            URL could not be fetched as an image).
    '''
    image = None
    if url_input and validators.url(url_input):
        # Reuse the shared download helper instead of duplicating it here.
        image = get_original_image(url_input)
    elif image_input is not None:
        image = image_input
    elif webcam_input is not None:
        image = webcam_input

    if image is None:
        raise ValueError(
            'No image to classify: provide a valid image URL, '
            'upload an image, or capture one from the webcam.'
        )

    # Make prediction
    return classify(image)
def set_example_image(example: list) -> dict:
    '''Build a ``gr.Image`` update showing the first entry of ``example``.'''
    selected = example[0]
    return gr.Image.update(value=selected)
def set_example_url(example: list) -> tuple:
    '''Build updates for the URL textbox and its image preview.

    Args:
        example: Example row whose first entry is the image URL.

    Returns:
        A 2-tuple ``(textbox_update, image_update)``: the textbox is set to
        the URL and the image to whatever ``get_original_image`` returns
        for that URL.
    '''
    url = example[0]
    textbox_update = gr.Textbox.update(value=url)
    image_update = gr.Image.update(value=get_original_image(url))
    return textbox_update, image_update
# Page heading; the id is targeted by the stylesheet in ``css`` below.
title = """<h1 id="title">Plant Health Classification with ViT</h1>"""

# Markdown shown under the heading: what the app does and how to use it.
description = """
This Plant Health classifier app was built to detect the health of plants using images of leaves by fine-tuning a Vision Transformer (ViT) [google/vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224) on the [Beans](https://huggingface.co/datasets/beans) dataset.
The finetuned model has an accuracy of 98.4% on the test (unseen) dataset and 100% on the validation dataset.
How to use the app:
- Upload an image via 3 options, uploading the image from local device, using a URL (image from the web) or a webcam
- The app will take a few seconds to generate a prediction with the following labels:
- *angular_leaf_spot*
- *bean_rust*
- *healthy*
- Feel free to click the image examples as well.
"""

# Example URLs offered in the "Image URL" tab.
urls = [
    "https://www.healthbenefitstimes.com/green-beans/",
    "https://huggingface.co/nateraw/vit-base-beans/resolve/main/angular_leaf_spot.jpeg",
    "https://huggingface.co/nateraw/vit-base-beans/resolve/main/bean_rust.jpeg",
]

# Example files for the upload tab: one single-element row per file found
# (recursively) under ./images whose name matches '*.p*g', in sorted order.
_example_paths = sorted(pathlib.Path('images').rglob('*.p*g'))
images = [[example_path.as_posix()] for example_path in _example_paths]

# Markdown badge linking to the author's Twitter profile.
twitter_link = """
[![](https://img.shields.io/twitter/follow/nickmuchi?label=@nickmuchi&style=social)](https://twitter.com/nickmuchi)
"""

# Stylesheet passed to gr.Blocks: centres the page heading.
css = '''
h1#title {
text-align: center;
}
'''
# Gradio UI: a header section followed by one tab per image source
# (file upload, URL, webcam), each with its own result label and button.
demo = gr.Blocks(css=css)

with demo:
    gr.Markdown(title)
    gr.HTML('<center><img src="file/images/Healthy.png" width=350px height=350px></center>')
    gr.Markdown(description)
    gr.Markdown(twitter_link)

    with gr.Tabs():
        with gr.TabItem('Image Upload'):
            with gr.Row():
                with gr.Column():
                    img_input = gr.Image(type='pil', shape=(450, 450))
                    label_from_upload = gr.Label(num_top_classes=3)

            with gr.Row():
                example_images = gr.Examples(examples=images, inputs=[img_input])

            img_but = gr.Button('Classify')

        with gr.TabItem('Image URL'):
            with gr.Row():
                with gr.Column():
                    url_input = gr.Textbox(lines=2, label='Enter valid image URL here..')
                    original_image = gr.Image(shape=(450, 450))
                    # Refresh the preview whenever the URL text changes.
                    url_input.change(get_original_image, url_input, original_image)
                with gr.Column():
                    label_from_url = gr.Label(num_top_classes=3)

            with gr.Row():
                example_url = gr.Examples(examples=urls, inputs=[url_input])

            url_but = gr.Button('Classify')

        with gr.TabItem('WebCam'):
            with gr.Row():
                with gr.Column():
                    web_input = gr.Image(source='webcam', type='pil', shape=(450, 450), streaming=True)
                with gr.Column():
                    label_from_webcam = gr.Label(num_top_classes=3)

            cam_but = gr.Button('Classify')

    # Each button forwards only its own tab's input. detect_plant_health
    # prefers a valid URL over the other sources, so passing all three inputs
    # from every button would let a URL left in the "Image URL" tab override
    # the image the user is actually classifying in another tab.
    url_but.click(
        lambda url: detect_plant_health(url, None, None),
        inputs=[url_input],
        outputs=[label_from_url],
        queue=True,
    )
    img_but.click(
        lambda image: detect_plant_health("", image, None),
        inputs=[img_input],
        outputs=[label_from_upload],
        queue=True,
    )
    cam_but.click(
        lambda frame: detect_plant_health("", None, frame),
        inputs=[web_input],
        outputs=[label_from_webcam],
        queue=True,
    )

    gr.Markdown("![visitor badge](https://visitor-badge.glitch.me/badge?page_id=nickmuchi-plant-health)")

demo.launch(debug=True, enable_queue=True)