faceyacc's picture
added requirements.txt
a502dd9
raw history blame
No virus
1.21 kB
import transformers
import gradio as gr
import datasets
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from transformers import ViTFeatureExtractor, ViTForImageClassification
# Module-level setup: load the "beans" dataset (used only for its label names)
# plus the feature extractor and classifier from "saved_model_files"
# (presumably a directory shipped with the app — TODO confirm it is not a Hub id).
# Bug fix: only the `datasets` module is imported, so the bare `load_dataset`
# call raised NameError; call it through the module instead.
# NOTE(review): this downloads/prepares the dataset at startup just to read
# label names; model.config.id2label may be a lighter source — confirm it
# carries the same names in the same order before switching.
dataset = datasets.load_dataset('beans', 'full_size')
extractor = AutoFeatureExtractor.from_pretrained("saved_model_files")
model = AutoModelForImageClassification.from_pretrained("saved_model_files")
# Label names are paired with model output indices by position, so this
# assumes the model was trained with the dataset's label order.
labels = dataset['train'].features['labels'].names
def classify(im):
    """Classify a bean-leaf image and return per-label confidences.

    Args:
        im: Image from the Gradio ``"image"`` input component (presumably a
            NumPy array — TODO confirm for the Gradio version in use); it is
            passed unchanged to the module-level ``extractor``.

    Returns:
        dict mapping each name in ``labels`` to its softmax probability as a
        plain Python float, as consumed by Gradio's ``"label"`` output.
    """
    # Bug fix: the extractor is bound as ``extractor`` at module level; the
    # original referenced an undefined ``feature_extractor`` (NameError).
    features = extractor(im, return_tensors='pt')
    # Inference only: skip autograd bookkeeping rather than detaching later.
    with torch.no_grad():
        outputs = model(features["pixel_values"])
    # Read logits by name: positional [-1] returns a different field
    # (hidden states / attentions) whenever the model is configured to emit them.
    probability = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # Only one image is submitted per call, so take the first batch row.
    probs = probability[0].numpy()
    return {label: float(probs[i]) for i, label in enumerate(labels)}
# NOTE(review): `description`, `title` and `examples` are never used below —
# gr.Interface is built with its own inline title/description strings.
# Bug fix: typo "wit" -> "with".
description = "Bean leaf health classification with Google's ViT"
title = "Bean Leaf Health Check"
# NOTE(review): this is a saved prediction (label/confidence text), not an
# input image, so it is not usable as `examples=` for an image Interface as-is.
examples = [["'angular_leaf_spot': 0.9999030828475952, 'bean_rust': 5.320278796716593e-05, 'healthy': 4.378804806037806e-05"]]
# Image in, label/confidence display out, with `classify` as the prediction fn.
gr_interface = gr.Interface(
    classify,
    inputs='image',
    outputs='label',
    title='Bean Classification',
    description='Monitor your crops health in easier way',
)
# debug=True makes launch() block the main thread so errors stay visible.
gr_interface.launch(debug=True)