# Streamlit app: zero-shot image classification with CLIP.
#
# The user picks a CLIP checkpoint, uploads an image and types candidate class
# names; the app shows the per-class probabilities returned by the model.
# NOTE: a `#` header is used instead of a module docstring so that Streamlit
# "magic" can never render it on the page.
from collections.abc import Iterable
from time import time

import streamlit as st
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor


def separator(chaine: str, sep: Iterable[str] = (',', ';')) -> list[str]:
    """Split *chaine* on whitespace and on every extra separator in *sep*.

    This is ``str.split()`` extended with additional separators: each
    separator is first replaced by a newline, then the text is split on any
    run of whitespace, so empty tokens are never produced.

    Args:
        chaine: The text to split.
        sep: Extra separator strings treated like whitespace. The default is
            an immutable tuple so no state is shared between calls.

    Returns:
        The non-empty tokens, in order of appearance.
    """
    for s in sep:
        chaine = chaine.replace(s, '\n')
    # '\n' is itself whitespace, so a single split() is equivalent to
    # splitting on newlines first and on whitespace second.
    return chaine.split()


@st.cache_resource
def _load_clip(model_name: str):
    """Load the CLIP model and processor for *model_name*.

    Wrapped in ``st.cache_resource`` so the objects are built once per
    checkpoint name instead of on every click of the "Classify" button.

    Returns:
        A ``(model, processor)`` tuple.
    """
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(
        pretrained_model_name_or_path=model_name
    )
    return model, processor


st.set_page_config('Zero Shot Image', '🏷️', 'wide')
st.title(
    '🏷️ Zero Shot Image Classification',
    help='Zero-shot image classification is the task of classifying previously unseen classes during training of a model',
)

# Placeholder at the top of the page, used for validation warnings.
info = st.empty()

col1, col2, col3 = st.columns(3, gap='large')

# Model
MODELS = {
    'Fast': 'openai/clip-vit-base-patch32',
    'Accurate': 'openai/clip-vit-large-patch14',
}
model_type = col1.radio('Model', list(MODELS), horizontal=True)

# Input image
image = None
_uploaded_image = col1.file_uploader('Image', ['png', 'jpg', 'jpeg', 'webp'])
if _uploaded_image is not None:
    image = Image.open(_uploaded_image)
    col2.image(image)

# Classes
_help = 'Enter classes separated by commas, semicolons, tabs, or newlines'
_classes = col1.text_area('Classes', '', help=_help, placeholder=_help)
# De-duplicate while keeping first-seen order: repeated labels would split the
# probability mass between identical prompts.
classes = list(dict.fromkeys(separator(_classes, [',', '\t', ';', '\n'])))

# Run
if col1.button('Classify', use_container_width=True, type='primary'):
    if image is None:
        info.warning('Upload an image to start')
    elif len(classes) < 2:
        info.warning('Enter at least two classes in the text area')
    else:
        # The timer covers model loading (cached after the first run),
        # preprocessing and inference.
        start = time()
        model, processor = _load_clip(MODELS[model_type])

        # Process
        inputs = processor(
            text=classes, images=image, return_tensors="pt", padding=True
        )

        # Outputs: inference only, so skip autograd bookkeeping.
        with torch.no_grad():
            outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)
        indice = torch.argmax(probs).item()

        with col3:
            # Best class as a heading, then one progress bar per class.
            st.markdown(f'## {classes[indice]}')
            result = probs.tolist()[0]
            for c, p in zip(classes, result):
                col3.progress(p, f'`{int(p * 100)}%` - {c}')
            # Elapsed wall-clock time for this classification.
            st.markdown(f'{round(time() - start, 2)}s')