Hunter-X-Hunter-Anime-Classification / pages /10-🎯 Zero-Shot Image Classification with CLIP.py
hafidhsoekma's picture
First commit
49bceed
import numpy as np
import streamlit as st
from PIL import Image
from streamlit_tags import st_tags
from models.misc import CLIPModel
from models.prototypical_networks import PrototypicalNetworksGradCAM
from utils import configs
from utils.functional import (
generate_empty_space,
get_default_images,
get_most_salient_object,
set_page_config,
set_seed,
)
# Set seed via the project helper before anything else runs
# (presumably for reproducible inference — confirm in utils.functional).
set_seed()
# Set page config through the project helper; the two arguments look like the
# page title and the page icon — confirm against utils.functional.set_page_config.
set_page_config("Zero-Shot Image Classification with CLIP", "🎯")
# Sidebar: two checkboxes, both ticked by default. Their current values are
# forwarded to load_model() as its freeze_model / pretrained_model arguments.
freeze_model = st.sidebar.checkbox("Freeze Model", value=True)
pretrained_model = st.sidebar.checkbox("Pretrained Model", value=True)
# Load Model
@st.cache_resource
def load_model(
    name_model: str,
    support_set_method: str = "5_shot",
    freeze_model: bool = True,
    pretrained_model: bool = True,
):
    """Build the CLIP classifier and its Grad-CAM companion for this page.

    The function is wrapped in ``st.cache_resource``, so Streamlit can hand
    back the already-built objects on reruns that pass the same arguments
    instead of constructing them again.

    Args:
        name_model: Model name forwarded as the first argument of ``CLIPModel``.
        support_set_method: Support-set setting forwarded to both constructors.
        freeze_model: Flag forwarded to both constructors.
        pretrained_model: Flag forwarded to both constructors.

    Returns:
        A ``(CLIPModel, PrototypicalNetworksGradCAM)`` tuple, in that order.
    """
    # Both constructors take the same (freeze, pretrained) pair in the same
    # positions, so bundle it once and unpack it positionally.
    shared_flags = (freeze_model, pretrained_model)

    classifier = CLIPModel(name_model, *shared_flags, support_set_method)
    grad_cam = PrototypicalNetworksGradCAM("clip", *shared_flags, support_set_method)
    return classifier, grad_cam
# Build the models for the current sidebar settings. support_set_method is not
# passed, so load_model's default ("5_shot") applies.
clip_model, custom_grad_cam = load_model(
    configs.CLIP_NAME_MODEL,
    freeze_model=freeze_model,
    pretrained_model=pretrained_model,
)

# Application Description
st.markdown("# ❓ Application Description")
# NOTE: the closing emoji was stored as mis-decoded UTF-8 bytes ("πŸš€");
# it is restored to the intended rocket so it renders correctly in the page.
st.write(
    """
Zero-Shot Image Classification with CLIP is an innovative application that allows users to classify images using natural language without the need for any training data. By leveraging state-of-the-art natural language processing and computer vision models, CLIP can understand the relationship between images and text, enabling it to accurately classify images based on their content.
This application is particularly useful in situations where large amounts of labeled data are not available or when there is a need to classify images based on non-traditional categories. Users can simply provide a description of the image in natural language and the CLIP model will classify it accordingly. Additionally, Zero-Shot Image Classification with CLIP can be used to perform tasks such as image retrieval and image captioning.
The application's ability to classify images without the need for traditional training data makes it an incredibly powerful tool in the field of computer vision. With the ability to classify images based on natural language descriptions, users can easily and quickly perform tasks that were previously impossible. With its innovative approach to image classification, Zero-Shot Image Classification with CLIP is poised to revolutionize the way we think about computer vision. 🚀
"""
)
# ---- Image input: an upload, with a bundled default image as fallback ----
uploaded_file = st.file_uploader(
    "Upload image file ",
    type=["jpg", "jpeg", "png", "bmp", "tiff"],
)
select_default_images = st.selectbox(
    "Select default images ",
    get_default_images(),
)
st.caption("Default Images will be used if no image is uploaded.")
select_image_button = st.button("Select Image")
if select_image_button:
    st.success("Image selected")

generate_empty_space(2)

# ---- Class input: free-form tags, with a default class list as fallback ----
list_class = st_tags(
    label="Input list of classes",
    text="Press enter to add more",
)
select_default_class = st.selectbox(
    "Select default classes ",
    configs.LIST_DEFAULT_CLASSES_FOR_ZERO_SHOT.keys(),
)
st.caption("Default Class will be used if no class is inputted.")
select_class_button = st.button("Select Class")
if select_class_button:
    st.success("Class selected")

# ---- Persist confirmed choices in session state ----
# A button is only True on the run triggered by its click, so the selection is
# written to st.session_state to remain available on later reruns.
if select_image_button:
    # Prefer the uploaded file; otherwise use the chosen default image.
    image_source = select_default_images if uploaded_file is None else uploaded_file
    image = np.array(Image.open(image_source).convert("RGB"))
    st.session_state["image"] = image

if select_class_button:
    # Prefer the typed classes; otherwise use the chosen default class list.
    if len(list_class) > 0:
        chosen_classes = list_class
    else:
        chosen_classes = configs.LIST_DEFAULT_CLASSES_FOR_ZERO_SHOT[
            select_default_class
        ]
    st.session_state["list_class"] = chosen_classes
# Show the preview / classification UI only once both an image and a class
# list have been confirmed and stored in session state.
if (
    st.session_state.get("image") is not None
    and st.session_state.get("list_class") is not None
):
    image = st.session_state.get("image")
    # A tuple is used so the predicted label can be mapped back to its
    # position with .index() below.
    list_class = tuple(st.session_state.get("list_class"))

    # Three columns are created but only the middle one is used, which keeps
    # the preview and the button in the centre third of the page.
    _, preview_col, _ = st.columns(3)
    preview_col.write("## 📸 Preview Image")
    preview_col.image(image, use_column_width=True)
    predict_image_button = preview_col.button("Classify Image")
    generate_empty_space(2)

    if predict_image_button:
        with st.spinner("Classifying Image..."):
            # result_class is read below via the keys "class", "confidence"
            # and "inference_time".
            result_class = clip_model.predict(image, list_class)
            # Grad-CAM is computed for the predicted class, addressed by its
            # index within list_class.
            result_grad_cam = custom_grad_cam.get_grad_cam_with_output_target(
                image, list_class.index(result_class["class"])
            )
            inference_time = result_class["inference_time"]

            col1, col2, col3 = st.columns(3)
            col1.write("### 🙂 Source Image")
            col1.image(image, use_column_width=True)
            col2.write("### 😎 Grad CAM Image")
            col2.image(result_grad_cam, use_column_width=True)
            col3.write("### 🤔 Most Salient Object")
            col3.image(get_most_salient_object(image), use_column_width=True)

            st.write("### 📝 Result")
            st.write(f"Image Class: {result_class['class'].title()}")
            st.write(f"Confidence Score: {result_class['confidence']* 100:.2f}%")
            st.write(f"Inference Time: {inference_time:.2f} s")

        # Clear the stored selection after showing the result, so the next
        # rerun skips this section until a new image and class list are chosen.
        st.session_state["image"] = None
        st.session_state["list_class"] = None