LukeStiN committed on
Commit 037b127
1 Parent(s): 8284812
Files changed (1)
  1. app.py +55 -16
app.py CHANGED
@@ -1,35 +1,74 @@
  import streamlit as st

  from PIL import Image
- # import torch, os
+ import torch
+ from time import time

  from transformers import CLIPProcessor, CLIPModel

+ def separator(chaine: str, sep: list[str] = [',', ';']) -> list[str]:
+     """Augmented ''.split(): split a string on several separators."""
+     elements = []
+     for s in sep :
+         chaine = chaine.replace(s, '\n')
+
+     for ligne in chaine.split("\n"):
+         elements.extend(ligne.split())
+
+     return elements
+
+ st.set_page_config('Zero Shot Image', '🏷️', 'wide')
+ st.title('🏷️ Zero Shot Image Classification', help='Zero-shot image classification is the task of classifying images into classes that were not seen during training')
+
+ info = st.empty()
+ col1, col2, col3 = st.columns(3, gap='large')
+
+ # Model
+ MODELS = {
+     'Fast' : 'openai/clip-vit-base-patch32',
+     'Accurate' : 'openai/clip-vit-large-patch14'
+ }
+ model_type = col1.radio('Model', MODELS.keys(), horizontal=True)
+
  # Input image
- _uploaded_image = st.file_uploader('Image', ['png', 'jpg', 'jpeg', 'webp'])
+ _uploaded_image = col1.file_uploader('Image', ['png', 'jpg', 'jpeg', 'webp'])
  if _uploaded_image is not None :
      image = Image.open(_uploaded_image)
-     st.image(image)
+     col2.image(image)

  # Classes
- _classes = st.text_input('Classes', '', help='Enter classes separated by commas', placeholder='Enter classes separated by commas', autocomplete='classes')
+ _help = 'Enter classes separated by commas, semicolons, tabs, or newlines'
+ _classes = col1.text_area('Classes', '', help=_help, placeholder=_help)
+ classes = [c.strip() for c in separator(_classes, [',', '\t', ';', '\n'])]

  # Run
- if st.button('Classify', use_container_width=True, type='primary') :
-
-     if ',' in _classes and _uploaded_image is not None :
+ if col1.button('Classify', use_container_width=True, type='primary') :
+     if _uploaded_image is None :
+         info.warning('Upload an image to start')

-         model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-         processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+     elif len(classes) < 2 :
+         info.warning('Enter at least two classes in the text area')
+
+     else :
+         # Load model
+         start = time()
+         model = CLIPModel.from_pretrained(MODELS[model_type])
+         processor = CLIPProcessor.from_pretrained(pretrained_model_name_or_path=MODELS[model_type])

-         classes = [c.strip() for c in _classes.split(',')]
+         # Process
          inputs = processor(text=classes, images=image, return_tensors="pt", padding=True)

+         # Outputs
          outputs = model(**inputs)
-         logits_per_image = outputs.logits_per_image # this is the image-text similarity score
-         probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
-         # indice = torch.argmax().item()
-
-         # classes[indice]
+         logits_per_image = outputs.logits_per_image
+         probs = logits_per_image.softmax(dim=1)
+         indice = torch.argmax(probs).item()
+
+         with col3 : # Progress bars
+             f'## {classes[indice]}'  # bare expressions are rendered by Streamlit "magic"

-         max(probs)
+             result = probs.tolist()[0]
+             for i, c in enumerate(classes) :
+                 col3.progress(result[i], f'`{int(result[i]*100)}%` - {c}')
+
+             f'{round(time()-start, 2)}s'
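
For reference, the inference path this commit introduces can be exercised outside Streamlit. The sketch below is a minimal standalone version: the checkpoint names and transformers calls (CLIPModel/CLIPProcessor, logits_per_image, softmax) are the same ones used in app.py, while the image path photo.jpg and the class list are hypothetical placeholders, not part of the commit.

from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Same 'Fast' checkpoint as in the MODELS dict of app.py
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Hypothetical inputs, for illustration only
image = Image.open("photo.jpg")           # any local image file
classes = ["a cat", "a dog", "a car"]     # example labels, not from the commit

# Tokenize the labels and preprocess the image in a single call
inputs = processor(text=classes, images=image, return_tensors="pt", padding=True)

# Forward pass: logits_per_image holds image-text similarity scores
outputs = model(**inputs)
probs = outputs.logits_per_image.softmax(dim=1)

# Print each label with its probability, mirroring the col3 progress bars
for label, p in zip(classes, probs.tolist()[0]):
    print(f"{label}: {p:.2%}")

Under the same assumptions, the separator helper would turn a raw text-area value such as 'cat, dog; bird' into ['cat', 'dog', 'bird'], which is the list the processor receives as text.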