# bean_app / app.py
# Last commit: "add torch import" (011c740) by konneker
import torch
import datasets
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import gradio as gr
# UI text shown by the Gradio interface.
title_string = "What's up with my leaf?"
description = "Example bean leaf health classification app using a fine-tuned ViT model"

# Class names are read from the ClassLabel feature of the "beans" dataset's
# train split. The index order is assumed to match the fine-tuned model's
# output order — TODO confirm against the training script.
dataset = datasets.load_dataset("beans")
labels = dataset['train'].features['labels'].names

# The preprocessing extractor and the classifier share one local directory.
_MODEL_DIR = "saved_model_files"
extractor = AutoFeatureExtractor.from_pretrained(_MODEL_DIR)
model = AutoModelForImageClassification.from_pretrained(_MODEL_DIR)
def classify(im):
    """Classify a bean-leaf image and return a probability for each label.

    Args:
        im: The image delivered by the Gradio input component. It is passed
            unchanged to ``extractor``. (Presumably a numpy array from the
            image input — confirm against the installed gradio version.)

    Returns:
        dict mapping each name in ``labels`` to its softmax probability as a
        Python float. The Label output component displays this dict.
    """
    features = extractor(im, return_tensors='pt')
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(**features)
    # Apply softmax exactly once to the raw logits. The previous code applied
    # it a second time to values that were already probabilities. That keeps
    # the ranking but squashes the scores toward uniform, so the displayed
    # confidences were wrong.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0].numpy()
    return {label: float(probs[i]) for i, label in enumerate(labels)}
# Example images offered in the UI; paths are relative to the app directory.
example_paths = ['angular-leaf-spot-1.JPG', 'bean-rust.jpg']

# gradio 3.x deprecated the gr.inputs / gr.outputs namespaces and gradio 4.0
# removed them. Prefer the top-level components, and fall back to the legacy
# namespaces only on older gradio releases that lack them.
if hasattr(gr, "Image") and hasattr(gr, "Label"):
    image_input = gr.Image()
    label_output = gr.Label(num_top_classes=3)
else:
    image_input = gr.inputs.Image()
    label_output = gr.outputs.Label(num_top_classes=3)

interface = gr.Interface(
    classify,
    image_input,
    label_output,
    description=description,
    examples=example_paths,
    title=title_string,
)
interface.launch()