app.py
CHANGED
@@ -1,35 +1,74 @@
import streamlit as st

from PIL import Image
-
+import torch
+from time import time

from transformers import CLIPProcessor, CLIPModel

+def separator(chaine: str, sep: list[str] = [',', ';']) -> list[str]:
+    """Split on several separators: an augmented ''.split()."""
+    elements = []
+    for s in sep :
+        chaine = chaine.replace(s, '\n')
+
+    for ligne in chaine.split("\n"):
+        elements.extend(ligne.split())
+
+    return elements
+
+st.set_page_config('Zero Shot Image', '🏷️', 'wide')
+st.title('🏷️ Zero Shot Image Classification', help='Zero-shot image classification is the task of classifying previously unseen classes during training of a model')
+
+info = st.empty()
+col1, col2, col3 = st.columns(3, gap='large')
+
+# Model
+MODELS = {
+    'Fast' : 'openai/clip-vit-base-patch32',
+    'Accurate' : 'openai/clip-vit-large-patch14'
+}
+model_type = col1.radio('Model', MODELS.keys(), horizontal=True)
+
# Input image
-_uploaded_image =
+_uploaded_image = col1.file_uploader('Image', ['png', 'jpg', 'jpeg', 'webp'])
if _uploaded_image is not None :
    image = Image.open(_uploaded_image)
-
+    col2.image(image)

# Classes
-
+_help = 'Enter classes separated by commas, semicolons, tabs, or newlines'
+_classes = col1.text_area('Classes', '', help=_help, placeholder=_help)
+classes = [c.strip() for c in separator(_classes, [',', '\t', ';', '\n'])]

# Run
-if
-
-
+if col1.button('Classify', use_container_width=True, type='primary') :
+    if _uploaded_image is None :
+        info.warning('Upload an image to start')

-
-
+    elif len(classes) < 2 :
+        info.warning('Enter at least two classes in the text area')
+
+    else :
+        # Load model
+        start = time()
+        model = CLIPModel.from_pretrained(MODELS[model_type])
+        processor = CLIPProcessor.from_pretrained(pretrained_model_name_or_path=MODELS[model_type])

-
+        # Process
        inputs = processor(text=classes, images=image, return_tensors="pt", padding=True)

+        # Outputs
        outputs = model(**inputs)
-logits_per_image = outputs.logits_per_image
-probs = logits_per_image.softmax(dim=1)
-
-
-#
+        logits_per_image = outputs.logits_per_image
+        probs = logits_per_image.softmax(dim=1)
+        indice = torch.argmax(probs).item()
+
+        with col3 : # Progress bars
+            f'## {classes[indice]}'

-
+            result = probs.tolist()[0]
+            for i, c in enumerate(classes) :
+                col3.progress(result[i], f'`{int(result[i]*100)}%` - {c}')
+
+            f'{round(time()-start, 2)}s'