# Zero-Shot-Image / app.py
# Author: LukeStiN
# Commit 037b127: "Valid app"
import streamlit as st
from PIL import Image
import torch
from time import time
from transformers import CLIPProcessor, CLIPModel
def separator(chaine: str, sep: 'tuple[str, ...] | list[str]' = (',', ';')) -> list[str]:
    """Split *chaine* on whitespace and on every string in *sep*.

    Behaves like ``str.split()`` (runs of whitespace separate items and empty
    items are dropped), augmented with the extra separators listed in *sep*.
    Separators may be longer than one character.

    Args:
        chaine: Text to split.
        sep: Extra separators, applied in order. Defaults to comma and
            semicolon. Empty strings are ignored, because replacing an empty
            pattern would otherwise split between every character.

    Returns:
        The non-empty items, without surrounding whitespace, in original order.
    """
    for s in sep:
        if s:
            chaine = chaine.replace(s, '\n')
    # '\n' is whitespace, so a single str.split() handles both the replaced
    # separators and ordinary spaces/tabs/newlines.
    return chaine.split()
# Page chrome: browser-tab title/icon and a full-width layout.
st.set_page_config(page_title='Zero Shot Image', page_icon='🏷️', layout='wide')
_TITLE_HELP = 'Zero-shot image classification is the task of classifying previously unseen classes during training of a model'
st.title('🏷️ Zero Shot Image Classification', help=_TITLE_HELP)
# Placeholder above the columns; the Classify handler writes warnings here.
info = st.empty()
# Three equal columns: controls | uploaded image | results.
col1, col2, col3 = st.columns(spec=3, gap='large')
# Model choice: radio label -> Hugging Face checkpoint id.
MODELS = {
    'Fast': 'openai/clip-vit-base-patch32',
    'Accurate': 'openai/clip-vit-large-patch14',
}
model_type = col1.radio(label='Model', options=list(MODELS), horizontal=True)
# Input image: shown in the middle column once it decodes successfully.
_uploaded_image = col1.file_uploader('Image', ['png', 'jpg', 'jpeg', 'webp'])
if _uploaded_image is not None:
    try:
        image = Image.open(_uploaded_image)
        # Image.open is lazy; force decoding so a corrupt or truncated file
        # fails here instead of later inside st.image or the processor.
        image.load()
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        col1.error('Could not read this file as an image')
        # Treat it as "no upload" so the Classify handler asks for a new image.
        _uploaded_image = None
    else:
        col2.image(image)
# Classes: free-text labels typed by the user.
_help = 'Enter classes separated by commas, semicolons, tabs, or newlines'
_classes = col1.text_area('Classes', '', help=_help, placeholder=_help)
# Split only on the delimiters promised in the help text, so multi-word
# labels such as "golden retriever" or "a photo of a cat" stay intact.
# Blank entries are dropped. Duplicates are removed (first occurrence wins)
# because identical prompts would split the probability between copies.
_parts = _classes.translate(str.maketrans(',;\t', '\n\n\n')).split('\n')
classes = list(dict.fromkeys(p.strip() for p in _parts if p.strip()))
# Run
@st.cache_resource
def _load_clip(checkpoint: str):
    """Load the CLIP model and processor for a Hugging Face *checkpoint*.

    Cached with ``st.cache_resource`` so each checkpoint is loaded once per
    server process instead of on every click of "Classify".

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)
    return model, processor


if col1.button('Classify', use_container_width=True, type='primary'):
    if _uploaded_image is None:
        info.warning('Upload an image to start')
    elif len(classes) < 2:
        info.warning('Enter at least two classes in the text area')
    else:
        start = time()
        model, processor = _load_clip(MODELS[model_type])
        inputs = processor(text=classes, images=image, return_tensors="pt", padding=True)
        # Inference only: disabling autograd avoids building a gradient graph.
        with torch.no_grad():
            outputs = model(**inputs)
        # logits_per_image is (n_images, n_classes); a single image is sent,
        # so row 0 holds this image's scores over the classes.
        probs = outputs.logits_per_image.softmax(dim=1)[0]
        best = int(torch.argmax(probs))
        scores = probs.tolist()
        with col3:  # Results: best label, one progress bar per class, timing
            st.markdown(f'## {classes[best]}')
            for label, score in zip(classes, scores):
                st.progress(score, f'`{int(score * 100)}%` - {label}')
            st.markdown(f'{round(time() - start, 2)}s')