"""Streamlit page: Zero-Shot Image Classification with CLIP.

Flow: the user picks (or uploads) an image and a list of candidate class
names; CLIP scores the image against the class names and the page shows
the prediction alongside a Grad-CAM heatmap and a saliency view.
Selections are parked in ``st.session_state`` so they survive Streamlit's
rerun-on-interaction model.
"""

import numpy as np
import streamlit as st
from PIL import Image
from streamlit_tags import st_tags

from models.misc import CLIPModel
from models.prototypical_networks import PrototypicalNetworksGradCAM
from utils import configs
from utils.functional import (
    generate_empty_space,
    get_default_images,
    get_most_salient_object,
    set_page_config,
    set_seed,
)

# Set seed for reproducible model behavior across reruns.
set_seed()

# Set page config (title + emoji favicon).
set_page_config("Zero-Shot Image Classification with CLIP", "🎯")

# Sidebar toggles; they are forwarded into the cached model loader, so
# flipping either one builds (and caches) a different model variant.
freeze_model = st.sidebar.checkbox("Freeze Model", value=True)
pretrained_model = st.sidebar.checkbox("Pretrained Model", value=True)


# Load Model
@st.cache_resource
def load_model(
    name_model: str,
    support_set_method: str = "5_shot",
    freeze_model: bool = True,
    pretrained_model: bool = True,
):
    """Build and cache the CLIP classifier plus its Grad-CAM companion.

    Cached with ``st.cache_resource`` so the heavyweight models are
    constructed once per unique argument combination, not on every rerun.
    The ``freeze_model``/``pretrained_model`` parameters intentionally
    mirror the sidebar checkbox names: they are part of the cache key.

    Args:
        name_model: Identifier of the CLIP backbone to instantiate.
        support_set_method: Support-set configuration tag (default "5_shot").
        freeze_model: Whether the backbone weights are frozen.
        pretrained_model: Whether to start from pretrained weights.

    Returns:
        Tuple of ``(clip_model, custom_grad_cam)``.
    """
    clip_model = CLIPModel(
        name_model, freeze_model, pretrained_model, support_set_method
    )
    # Grad-CAM wrapper is keyed to the "clip" architecture name.
    custom_grad_cam = PrototypicalNetworksGradCAM(
        "clip",
        freeze_model,
        pretrained_model,
        support_set_method,
    )
    return clip_model, custom_grad_cam


clip_model, custom_grad_cam = load_model(
    configs.CLIP_NAME_MODEL,
    freeze_model=freeze_model,
    pretrained_model=pretrained_model,
)

# Application Description
st.markdown("# ❓ Application Description")
st.write(
    """
Zero-Shot Image Classification with CLIP is an innovative application that allows users to classify images using natural language without the need for any training data. By leveraging state-of-the-art natural language processing and computer vision models, CLIP can understand the relationship between images and text, enabling it to accurately classify images based on their content. This application is particularly useful in situations where large amounts of labeled data are not available or when there is a need to classify images based on non-traditional categories. Users can simply provide a description of the image in natural language and the CLIP model will classify it accordingly.

Additionally, Zero-Shot Image Classification with CLIP can be used to perform tasks such as image retrieval and image captioning. The application's ability to classify images without the need for traditional training data makes it an incredibly powerful tool in the field of computer vision. With the ability to classify images based on natural language descriptions, users can easily and quickly perform tasks that were previously impossible. With its innovative approach to image classification, Zero-Shot Image Classification with CLIP is poised to revolutionize the way we think about computer vision. 🚀
"""
)

# --- Image selection widgets -------------------------------------------------
uploaded_file = st.file_uploader(
    "Upload image file ", type=["jpg", "jpeg", "png", "bmp", "tiff"]
)
select_default_images = st.selectbox("Select default images ", get_default_images())
st.caption("Default Images will be used if no image is uploaded.")
select_image_button = st.button("Select Image")
if select_image_button:
    st.success("Image selected")
generate_empty_space(2)

# --- Class selection widgets -------------------------------------------------
list_class = st_tags(
    label="Input list of classes",
    text="Press enter to add more",
)
select_default_class = st.selectbox(
    "Select default classes ",
    configs.LIST_DEFAULT_CLASSES_FOR_ZERO_SHOT.keys(),
)
st.caption("Default Class will be used if no class is inputted.")
select_class_button = st.button("Select Class")
if select_class_button:
    st.success("Class selected")

# Persist the chosen image: uploaded file wins, otherwise the default
# selection (assumed to be an openable path/file — provided by
# get_default_images()).
if select_image_button and uploaded_file is not None:
    image = np.array(Image.open(uploaded_file).convert("RGB"))
    st.session_state["image"] = image
elif select_image_button and uploaded_file is None:
    image = np.array(Image.open(select_default_images).convert("RGB"))
    st.session_state["image"] = image

# Persist the chosen classes: user tags win, otherwise the default set.
if select_class_button and len(list_class) > 0:
    st.session_state["list_class"] = list_class
elif select_class_button and len(list_class) == 0:
    st.session_state["list_class"] = configs.LIST_DEFAULT_CLASSES_FOR_ZERO_SHOT[
        select_default_class
    ]

# Only show the preview/classify UI once both an image and a class list
# have been committed to session state.
if (
    st.session_state.get("image") is not None
    and st.session_state.get("list_class") is not None
):
    image = st.session_state.get("image")
    list_class = tuple(st.session_state.get("list_class"))
    # Center the preview in the middle column; the outer columns are
    # spacers only (previously bound to unused names `col` and `col3`).
    _, col2, _ = st.columns(3)
    col2.write("## 📸 Preview Image")
    # NOTE(review): use_column_width is deprecated in newer Streamlit
    # releases (use_container_width) — kept as-is for compatibility.
    col2.image(image, use_column_width=True)
    predict_image_button = col2.button("Classify Image")
    generate_empty_space(2)
    if predict_image_button:
        with st.spinner("Classifying Image..."):
            result_class = clip_model.predict(image, list_class)
            # Grad-CAM is targeted at the index of the predicted class
            # within the candidate list.
            result_grad_cam = custom_grad_cam.get_grad_cam_with_output_target(
                image, list_class.index(result_class["class"])
            )
            inference_time = result_class["inference_time"]
            col1, col2, col3 = st.columns(3)
            col1.write("### 🙂 Source Image")
            col1.image(image, use_column_width=True)
            col2.write("### 😎 Grad CAM Image")
            col2.image(result_grad_cam, use_column_width=True)
            col3.write("### 🤔 Most Salient Object")
            col3.image(get_most_salient_object(image), use_column_width=True)
            st.write("### 📝 Result")
            st.write(f"Image Class: {result_class['class'].title()}")
            st.write(f"Confidence Score: {result_class['confidence'] * 100:.2f}%")
            st.write(f"Inference Time: {inference_time:.2f} s")
        # One-shot flow: clear the selections after showing results so
        # the user starts fresh on the next interaction.
        st.session_state["image"] = None
        st.session_state["list_class"] = None