File size: 1,714 Bytes
0eae57d
 
 
 
ada2f3f
0eae57d
f6f1dea
0eae57d
e505593
 
0eae57d
 
8d3d679
 
0eae57d
 
 
 
 
8d3d679
0eae57d
 
 
b626dcd
 
0eae57d
 
 
b671049
69447b1
 
 
 
 
c478e88
 
b671049
c478e88
69447b1
0eae57d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import datasets
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import numpy as np
import gradio as gr
import torch

# The "beans" dataset is loaded only so the class names can be read from it
# below.
# NOTE(review): this downloads the full dataset at startup just for three
# label names.
dataset = datasets.load_dataset("beans")

# The feature extractor and the fine-tuned classifier come from the same Hub
# checkpoint.
_checkpoint = "andresgtn/vit-base-bean-health-classifier"
extractor = AutoFeatureExtractor.from_pretrained(_checkpoint)
model = AutoModelForImageClassification.from_pretrained(_checkpoint)
# The model is left on the CPU; no device placement happens here.
# from_pretrained already returns the model in eval mode, so an explicit
# model.eval() is unnecessary.

# Class names from the dataset's 'labels' feature.
# NOTE(review): these are assumed to be in the same index order as the
# model's output logits. Verify against model.config.id2label.
labels = dataset['train'].features['labels'].names

def classify(im):
    """Classify a bean-leaf image and return per-label confidences.

    Args:
        im: Image handed over by the Gradio ``Image`` input component
            (presumably a numpy array, Gradio's default image type; confirm).

    Returns:
        dict mapping each name in the module-level ``labels`` list to the
        softmax probability the model assigns to that class index, as a
        plain Python float (suitable for Gradio's ``'label'`` output).
    """
    features = extractor(im, return_tensors='pt')
    # Put the inputs wherever the model lives, so inference keeps working
    # if the model is ever moved to a GPU.
    features = features.to(model.device)
    with torch.no_grad():
        logits = model(**features).logits
    probability = torch.nn.functional.softmax(logits, dim=-1)
    # .numpy() only works on CPU tensors. .cpu() is a no-op when the model
    # already runs on the CPU. No .detach() is needed: the logits were
    # produced under no_grad.
    probs = probability[0].cpu().numpy()
    return {label: float(probs[i]) for i, label in enumerate(labels)}

# Example inputs shown under the demo. Each inner list holds one value per
# input component; there is a single image input, so each is [url].
sample_images = [
    [url]
    for url in (
        'https://s3.amazonaws.com/moonup/production/uploads/1663933284359-611f9702593efbee33a4f7c9.png',
        'https://s3.amazonaws.com/moonup/production/uploads/1663933284374-611f9702593efbee33a4f7c9.png',
        'https://s3.amazonaws.com/moonup/production/uploads/1663933284412-611f9702593efbee33a4f7c9.png',
    )
]

title = "Bean leaf disease classifier"
description = "Upload an image of a bean leaf to find out if it is diseased"

# Image in, per-class confidences out (rendered by Gradio's 'label' output).
# NOTE(review): `shape=` is a Gradio 3.x argument of gr.Image and was removed
# in Gradio 4. Pin gradio<4 or migrate this call when upgrading.
interface = gr.Interface(
    fn=classify,
    inputs=gr.Image(shape=(200, 200)),
    outputs='label',
    examples=sample_images,
    title=title,
    description=description,
)

interface.launch(debug=False)