rsadaphule commited on
Commit
eac6703
1 Parent(s): 667ae74

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -0
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
+ import gradio as gr
4
+
5
+ model_id = f'rsadaphule/vit-base-patch16-224-finetuned-wildcats'
6
+ labels = ['AFRICAN LEOPARD',
7
+ 'CARACAL',
8
+ 'CHEETAH',
9
+ 'CLOUDED LEOPARD',
10
+ 'JAGUAR',
11
+ 'LIONS',
12
+ 'OCELOT',
13
+ 'PUMA',
14
+ 'SNOW LEOPARD',
15
+ 'TIGER']
16
+
17
+
18
+ def classify_image(image):
19
+ model = AutoModelForImageClassification.from_pretrained(model_id)
20
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
21
+ inp = feature_extractor(image, return_tensors='pt')
22
+ outp = model(**inp)
23
+ pred = torch.nn.functional.softmax(outp.logits, dim=-1)
24
+ preds = pred[0].cpu().detach().numpy()
25
+ confidence = {label: float(preds[i]) for i, label in enumerate(labels)}
26
+ return confidence
27
+
28
+ interface = gr.Interface(fn=classify_image,
29
+ inputs='image',
30
+ examples=['cat1.jpg', 'cat2.jpg'],
31
+ outputs='label').launch(debug=True, share=True)