Spaces:
Sleeping
Sleeping
import streamlit as st | |
from PIL import Image | |
import torch | |
from time import time | |
from transformers import CLIPProcessor, CLIPModel | |
def separator(chaine: str, sep: list[str] = [',', ';']) -> list[str]:
    """Split ``chaine`` on several delimiters at once.

    Every delimiter in ``sep`` is treated as an item boundary, and so is a
    newline (always, whether or not it is listed in ``sep``). Whitespace
    *inside* an item is preserved, so multi-word entries such as
    ``"golden retriever"`` stay in one piece; only the whitespace around
    each item is removed. Empty items (e.g. from ``"a,,b"`` or a trailing
    delimiter) are dropped.

    Args:
        chaine: The raw text to split.
        sep: Delimiter strings. The default list is only read, never
            mutated, so sharing it across calls is safe.

    Returns:
        The non-empty, stripped items in their original order.
    """
    # Normalise every delimiter to a newline so one split handles them all.
    for delimiter in sep:
        # An empty delimiter would make str.replace insert a newline between
        # every character, so it is ignored.
        if delimiter:
            chaine = chaine.replace(delimiter, '\n')
    stripped = (piece.strip() for piece in chaine.split('\n'))
    return [piece for piece in stripped if piece]
# Page chrome: tab title, tab icon and a wide layout.
st.set_page_config(page_title='Zero Shot Image', page_icon='🏷️', layout='wide')
st.title(
    '🏷️ Zero Shot Image Classification',
    help='Zero-shot image classification is the task of classifying previously unseen classes during training of a model',
)
# Empty placeholder created ahead of the columns.
info = st.empty()
col1, col2, col3 = st.columns(3, gap='large')

# Model: radio label -> Hugging Face checkpoint name.
MODELS = {
    'Fast': 'openai/clip-vit-base-patch32',
    'Accurate': 'openai/clip-vit-large-patch14',
}
model_type = col1.radio(label='Model', options=MODELS.keys(), horizontal=True)

# Input image: `image` is only bound once a file has been uploaded; the
# preview is shown in the middle column.
_uploaded_image = col1.file_uploader(label='Image', type=['png', 'jpg', 'jpeg', 'webp'])
if _uploaded_image is not None:
    image = Image.open(_uploaded_image)
    col2.image(image)

# Classes: free text, split into candidate labels.
_help = 'Enter classes separated by commas, semicolons, tabs, or newlines'
_classes = col1.text_area(label='Classes', value='', help=_help, placeholder=_help)
_delimiters = [',', '\t', ';', '\n']
classes = [label.strip() for label in separator(_classes, _delimiters)]
# Run


@st.cache_resource
def _load_clip(checkpoint: str):
    """Return the ``(model, processor)`` pair for a CLIP checkpoint.

    Wrapped in ``st.cache_resource`` so each checkpoint is loaded once and
    reused across reruns instead of being reloaded on every button click.
    """
    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(pretrained_model_name_or_path=checkpoint)
    return model, processor


if col1.button('Classify', use_container_width=True, type='primary'):
    if _uploaded_image is None:
        info.warning('Upload an image to start')
    elif len(classes) < 2:
        # Wording matches the check above: at least two labels are required.
        info.warning('Insert at least two classes in the text area')
    else:
        # The timer covers model loading (cached after the first call) and
        # inference.
        start = time()
        model, processor = _load_clip(MODELS[model_type])

        # Process
        inputs = processor(text=classes, images=image, return_tensors="pt", padding=True)

        # Outputs: inference only, so no autograd graph is needed.
        with torch.no_grad():
            outputs = model(**inputs)
        # Softmax across the candidate labels; a single image was passed, so
        # the first row of `probs` is the one of interest.
        probs = outputs.logits_per_image.softmax(dim=1)
        indice = torch.argmax(probs).item()
        result = probs.tolist()[0]

        with col3:  # Best label, one progress bar per class, elapsed time
            st.write(f'## {classes[indice]}')
            for label, score in zip(classes, result):
                col3.progress(score, f'`{int(score*100)}%` - {label}')
            st.write(f'{round(time()-start, 2)}s')