schrilax commited on
Commit
e9572e7
1 Parent(s): df8324b

add implementation/interface

Browse files
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. app.py +9 -8
.DS_Store ADDED
Binary file (6.15 kB). View file
app.py CHANGED
@@ -2,21 +2,22 @@ import datasets
2
  import gradio as gr
3
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
4
 
5
- def classify(im):
6
- features = feature_extractor(im, return_tensors='pt')
7
- logits = model(features['pixel_values'])[-1]
8
- probability = torch.nn.functional.softmax(logits, dim=-1)
9
- probs = probability[0].detach().numpy()
10
- confidences = {label: float(probs[i]) for i, label in enumerate(labels)}
11
- return confidences
12
-
13
  dataset = datasets.load_dataset('beans', 'full_size')
14
 
 
15
  extractor = AutoFeatureExtractor.from_pretrained('saved_model_files')
16
  model = AutoModelForImageClassification.from_pretrained('saved_model_files')
17
 
18
  labels = dataset['train'].features['labels'].names
19
 
 
 
 
 
 
 
 
 
20
  interface = gr.Interface(fn=classify, inputs=gr.Image(shape=(200, 200)), outputs=gr.outputs.Label(num_top_classes=3),
21
  examples=['leaf1.png', 'leaf2.png'], title='Leaf Classification App', description='Check if the leaves of your plant are healthy!', flagging_dir='flagged_examples/')
22
 
2
  import gradio as gr
3
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
4
 
 
 
 
 
 
 
 
 
5
  dataset = datasets.load_dataset('beans', 'full_size')
6
 
7
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
8
  extractor = AutoFeatureExtractor.from_pretrained('saved_model_files')
9
  model = AutoModelForImageClassification.from_pretrained('saved_model_files')
10
 
11
  labels = dataset['train'].features['labels'].names
12
 
13
+ def classify(im):
14
+ features = extractor(im, return_tensors='pt')
15
+ logits = model(features['pixel_values'])[-1]
16
+ probability = torch.nn.functional.softmax(logits, dim=-1)
17
+ probs = probability[0].detach().numpy()
18
+ confidences = {label: float(probs[i]) for i, label in enumerate(labels)}
19
+ return confidences
20
+
21
  interface = gr.Interface(fn=classify, inputs=gr.Image(shape=(200, 200)), outputs=gr.outputs.Label(num_top_classes=3),
22
  examples=['leaf1.png', 'leaf2.png'], title='Leaf Classification App', description='Check if the leaves of your plant are healthy!', flagging_dir='flagged_examples/')
23