amanneo commited on
Commit
6f474e3
1 Parent(s): 72c250b

Updated app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -0
app.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
+ import streamlit as st
4
+ from PIL import Image
5
+
6
+ model_id = f'amanneo/vit-base-patch16-224-finetuned-flower'
7
+ labels = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
8
+
9
+ def classify_image(image):
10
+ model = AutoModelForImageClassification.from_pretrained(model_id)
11
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)
12
+ inp = feature_extractor(image, return_tensors='pt')
13
+ outp = model(**inp)
14
+ pred = torch.nn.functional.softmax(outp.logits, dim=-1)
15
+ preds = pred[0].cpu().detach().numpy()
16
+ confidence = {label: float(preds[i]) for i, label in enumerate(labels)}
17
+ return confidence
18
+
19
+ file_name = st.file_uploader("Upload flower image")
20
+ if file_name is not None:
21
+ col1,col2 = st.columns(2)
22
+ image = Image.open(file_name)
23
+ col1.image(image,use_column_width=True)
24
+ predictions = classify_image(image)
25
+ col2.header("Probabilities")
26
+ for l,p in predictions.items():
27
+ col2.subheader("{} : {}".format(l,p))