File size: 5,670 Bytes
49bceed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import numpy as np
import streamlit as st
from PIL import Image
from streamlit_tags import st_tags

from models.misc import CLIPModel
from models.prototypical_networks import PrototypicalNetworksGradCAM
from utils import configs
from utils.functional import (
    generate_empty_space,
    get_default_images,
    get_most_salient_object,
    set_page_config,
    set_seed,
)

# Fix random seeds up front via the project's `set_seed` helper.
set_seed()

# Page title/icon are configured before any other element is drawn
# (presumably wraps `st.set_page_config`, which Streamlit requires first).
set_page_config("Zero-Shot Image Classification with CLIP", "🎯")

# Sidebar toggles; both values are forwarded to `load_model`.
freeze_model = st.sidebar.checkbox(label="Freeze Model", value=True)
pretrained_model = st.sidebar.checkbox(label="Pretrained Model", value=True)


# Load Model
@st.cache_resource
def load_model(
    name_model: str,
    support_set_method: str = "5_shot",
    freeze_model: bool = True,
    pretrained_model: bool = True,
):
    """Build the zero-shot CLIP classifier and its Grad-CAM explainer.

    The function is wrapped in ``st.cache_resource``, so Streamlit reuses the
    returned objects across reruns for identical argument values instead of
    constructing them again.

    Args:
        name_model: Model name forwarded to ``CLIPModel``.
        support_set_method: Support-set setting (default ``"5_shot"``),
            forwarded to both constructors.
        freeze_model: Flag forwarded to both constructors.
        pretrained_model: Flag forwarded to both constructors.

    Returns:
        A ``(clip_model, custom_grad_cam)`` tuple.
    """
    # Both wrappers take the same trailing arguments in the same order.
    shared_args = (freeze_model, pretrained_model, support_set_method)
    classifier = CLIPModel(name_model, *shared_args)
    # NOTE(review): the Grad-CAM wrapper always receives the literal "clip"
    # rather than `name_model`; presumably a backbone identifier — confirm.
    explainer = PrototypicalNetworksGradCAM("clip", *shared_args)
    return classifier, explainer


# Build (or fetch from Streamlit's resource cache) the classifier/explainer
# pair; the sidebar toggles are part of the cache key via the arguments.
clip_model, custom_grad_cam = load_model(
    name_model=configs.CLIP_NAME_MODEL,
    freeze_model=freeze_model,
    pretrained_model=pretrained_model,
)

# Application Description
# Static introductory copy; `st.write` renders a plain string as Markdown.
st.markdown("# ❓ Application Description")
st.write(
    """

    Zero-Shot Image Classification with CLIP is an innovative application that allows users to classify images using natural language without the need for any training data. By leveraging state-of-the-art natural language processing and computer vision models, CLIP can understand the relationship between images and text, enabling it to accurately classify images based on their content.



This application is particularly useful in situations where large amounts of labeled data are not available or when there is a need to classify images based on non-traditional categories. Users can simply provide a description of the image in natural language and the CLIP model will classify it accordingly. Additionally, Zero-Shot Image Classification with CLIP can be used to perform tasks such as image retrieval and image captioning.



The application's ability to classify images without the need for traditional training data makes it an incredibly powerful tool in the field of computer vision. With the ability to classify images based on natural language descriptions, users can easily and quickly perform tasks that were previously impossible. With its innovative approach to image classification, Zero-Shot Image Classification with CLIP is poised to revolutionize the way we think about computer vision. 🚀

    """
)

# Image source widgets: an optional upload plus a dropdown of bundled images.
uploaded_file = st.file_uploader(
    label="Upload image file ",
    type=["jpg", "jpeg", "png", "bmp", "tiff"],
)
select_default_images = st.selectbox(
    label="Select default images ",
    options=get_default_images(),
)
st.caption("Default Images will be used if no image is uploaded.")
select_image_button = st.button(label="Select Image")
if select_image_button:
    st.success("Image selected")

generate_empty_space(2)
# Class-name widgets: free-form tags plus a dropdown of preset class lists.
list_class = st_tags(label="Input list of classes", text="Press enter to add more")
select_default_class = st.selectbox(
    label="Select default classes ",
    options=configs.LIST_DEFAULT_CLASSES_FOR_ZERO_SHOT.keys(),
)
st.caption("Default Class will be used if no class is inputted.")
select_class_button = st.button(label="Select Class")
if select_class_button:
    st.success("Class selected")

# Persist the current selections in session state: `st.button` is True only
# on the rerun triggered by its click, so the choices must be stored here to
# survive later reruns (e.g. the "Classify Image" click).
if select_image_button:
    # An uploaded file takes precedence over the default-image dropdown.
    image_source = (
        uploaded_file if uploaded_file is not None else select_default_images
    )
    # Use PIL as a context manager so any file handle it opened is released,
    # including for multi-frame formats (e.g. TIFF) that stay open after load.
    with Image.open(image_source) as pil_image:
        image = np.array(pil_image.convert("RGB"))
    st.session_state["image"] = image

if select_class_button:
    # Strip whitespace, drop empty tags and de-duplicate (first occurrence
    # wins) so each class name sent to the model is distinct and non-empty;
    # duplicates would otherwise make a later `list.index` lookup ambiguous.
    stripped_tags = (tag.strip() for tag in (list_class or []) if tag)
    typed_classes = list(dict.fromkeys(name for name in stripped_tags if name))
    if typed_classes:
        st.session_state["list_class"] = typed_classes
    else:
        # No usable typed classes: fall back to the selected default preset.
        st.session_state["list_class"] = configs.LIST_DEFAULT_CLASSES_FOR_ZERO_SHOT[
            select_default_class
        ]

# Show the preview and classification UI only once both an image and a class
# list have been committed to session state.
if (
    st.session_state.get("image") is not None
    and st.session_state.get("list_class") is not None
):
    image = st.session_state.get("image")
    # Kept as a tuple, as before; presumably predict() expects a hashable
    # sequence — NOTE(review): confirm.
    list_class = tuple(st.session_state.get("list_class"))
    # Draw the preview in the middle of three equal columns to centre it.
    _, preview_col, _ = st.columns(3)
    preview_col.write("## 📸 Preview Image")
    preview_col.image(image, use_column_width=True)
    predict_image_button = preview_col.button("Classify Image")
    generate_empty_space(2)
    if predict_image_button:
        with st.spinner("Classifying Image..."):
            result_class = clip_model.predict(image, list_class)
            # Grad-CAM targets the predicted class by its position in the
            # class list; assumes predict() returns one of the supplied names
            # verbatim (`tuple.index` raises ValueError otherwise).
            result_grad_cam = custom_grad_cam.get_grad_cam_with_output_target(
                image, list_class.index(result_class["class"])
            )
            inference_time = result_class["inference_time"]
            # NOTE(review): newer Streamlit releases deprecate
            # `use_column_width` in favour of `use_container_width`; switch
            # once the pinned Streamlit version supports it.
            col1, col2, col3 = st.columns(3)
            col1.write("### 🙂 Source Image")
            col1.image(image, use_column_width=True)
            col2.write("### 😎 Grad CAM Image")
            col2.image(result_grad_cam, use_column_width=True)
            col3.write("### 🤔 Most Salient Object")
            col3.image(get_most_salient_object(image), use_column_width=True)
            st.write("### 📝 Result")
            st.write(f"Image Class: {result_class['class'].title()}")
            st.write(f"Confidence Score: {result_class['confidence'] * 100:.2f}%")
            st.write(f"Inference Time: {inference_time:.2f} s")
            # Clear the committed inputs so the next classification requires
            # the image and class list to be selected again.
            st.session_state["image"] = None
            st.session_state["list_class"] = None