Hunter-X-Hunter-Anime-Classification
/
pages
/05-π Image Embeddings with Prototypical Networks.py
import numpy as np | |
import streamlit as st | |
from PIL import Image | |
from models.prototypical_networks import ImageEmbeddings, PrototypicalNetworksGradCAM | |
from utils import configs | |
from utils.functional import ( | |
generate_empty_space, | |
get_default_images, | |
get_most_salient_object, | |
set_page_config, | |
set_seed, | |
) | |
# Seed and page configuration run first, before any widget is created.
set_seed()
set_page_config("Image Embeddings with Prototypical Networks", "π")

# Sidebar controls. The four values chosen here are the arguments later
# handed to load_model().
model_choices = tuple(configs.NAME_MODELS.keys())
name_model = st.sidebar.selectbox("Select Model", model_choices)
support_set_method = st.sidebar.selectbox(
    label="Select Support Set Method",
    options=configs.SUPPORT_SET_METHODS,
)
freeze_model = st.sidebar.checkbox("Freeze Model", value=True)
pretrained_model = st.sidebar.checkbox("Pretrained Model", value=True)
# Load Model
def load_model(
    name_model: str,
    support_set_method: str,
    freeze_model: bool,
    pretrained_model: bool,
):
    """Build the embedding extractor and its Grad-CAM counterpart.

    Both objects are constructed from the same four settings. The
    constructors receive them in a different order than this function's
    signature: ``support_set_method`` is passed last.

    Args:
        name_model: Model name, forwarded unchanged to both constructors.
        support_set_method: Support set method, forwarded unchanged.
        freeze_model: Freeze flag, forwarded unchanged.
        pretrained_model: Pretrained flag, forwarded unchanged.

    Returns:
        A tuple ``(ImageEmbeddings instance, PrototypicalNetworksGradCAM
        instance)``, in that order.
    """
    # NOTE(review): this function is called unconditionally at module level,
    # so both objects are rebuilt on every script run — consider caching
    # (e.g. Streamlit's resource cache) if construction is expensive; confirm
    # the installed Streamlit version and that the objects are safe to share.
    constructor_args = (name_model, freeze_model, pretrained_model, support_set_method)
    embedder = ImageEmbeddings(*constructor_args)
    grad_cam = PrototypicalNetworksGradCAM(*constructor_args)
    return embedder, grad_cam
# Instantiate both models from the current sidebar selection.
image_embeddings, custom_grad_cam = load_model(
    name_model=name_model,
    support_set_method=support_set_method,
    freeze_model=freeze_model,
    pretrained_model=pretrained_model,
)
# Application Description: a heading followed by the introductory text.
# The text is bound to a name first so the rendering calls stay short.
application_description = """
Introducing Image Embeddings with Prototypical Networks π, an innovative app that lets you extract rich and meaningful representations of your images using cutting-edge deep learning techniques. With our app, you can easily generate image embeddings that capture the essence of your visual data, allowing you to analyze and compare images in ways never before possible.
Whether you're a data scientist, a machine learning engineer, or just someone who loves working with visual data, our app is the perfect tool for you. With a user-friendly interface and intuitive features, you'll be generating high-quality image embeddings in no time.
Our app is also highly versatile, allowing you to use image embeddings for a wide range of applications, including image search, content-based image retrieval, and more. And with our Prototypical Networks model, you can rest assured that you're getting state-of-the-art performance and accuracy.
So why wait? Try out Image Embeddings with Prototypical Networks π today and unlock the full potential of your visual data!
"""
st.markdown("# β Application Description")
st.write(application_description)
# Image input: an optional upload plus a drop-down of bundled default images.
uploaded_file = st.file_uploader(
    "Upload image file ",
    type=["jpg", "jpeg", "png", "bmp", "tiff"],
)
select_default_images = st.selectbox("Select default images ", get_default_images())
st.caption("Default Images will be used if no image is uploaded.")
select_image_button = st.button("Select Image")

if select_image_button:
    st.success("Image selected")
    # An uploaded file takes priority; otherwise use the selected default.
    if uploaded_file is not None:
        image_source = uploaded_file
    else:
        image_source = select_default_images
    image = np.array(Image.open(image_source).convert("RGB"))
    # Stored in session state so the image is still available on later
    # script runs, when this button no longer reads as pressed.
    st.session_state["image"] = image
# Preview the selected image and, on request, compute and show its embeddings.
stored_image = st.session_state.get("image")
if stored_image is not None:
    image = stored_image

    # Only the middle of the three columns is used, which centres the preview.
    _, preview_col, _ = st.columns(3)
    preview_col.write("## πΈ Preview Image ")
    preview_col.image(image, use_column_width=True)
    predict_image_button = preview_col.button("Get Image Embeddings")
    generate_empty_space(2)

    if predict_image_button:
        with st.spinner("Getting Image Embeddings..."):
            result_embeddings = image_embeddings.get_embeddings(image)
            result_grad_cam = custom_grad_cam.get_grad_cam(image)

        inference_time = result_embeddings["inference_time"]
        embeddings = result_embeddings["embeddings"]

        # Side-by-side comparison: source, Grad-CAM output, salient object.
        source_col, grad_cam_col, salient_col = st.columns(3)
        source_col.write("### π Source Image")
        source_col.image(image, use_column_width=True)
        grad_cam_col.write("### π Grad CAM Image")
        grad_cam_col.image(result_grad_cam, use_column_width=True)
        salient_col.write("### π€ Most Salient Object")
        salient_col.image(get_most_salient_object(image), use_column_width=True)

        st.write("### π Result")
        st.json({"image_embeddings": embeddings.tolist()})
        st.write(f"Image Embeddings Shape: {embeddings.shape}")
        st.write(f"Inference Time: {inference_time:.2f} s")

        # Clear the stored image so a new one must be selected for the next run.
        st.session_state["image"] = None