# Hunter-X-Hunter-Anime-Classification / pages / 02-🦸 HxH Character Anime Classification with Prototypical Networks.py
# Commit 49bceed ("First commit") by hafidhsoekma.
import numpy as np
import streamlit as st
from PIL import Image
from models.prototypical_networks import (
PrototypicalNetworksGradCAM,
PrototypicalNetworksModel,
)
from utils import configs
from utils.functional import (
generate_empty_space,
get_default_images,
get_most_salient_object,
set_page_config,
set_seed,
)
# Seed setup via the project helper (utils.functional.set_seed).
set_seed()

# Page title and icon, passed to the project's set_page_config helper.
set_page_config("HxH Character Anime Classification with Prototypical Networks", "🦸")

# Sidebar: choose the backbone model and how it is configured/loaded.
sidebar = st.sidebar
name_model = sidebar.selectbox("Select Model", tuple(configs.NAME_MODELS.keys()))
support_set_method = sidebar.selectbox(
    "Select Support Set Method", configs.SUPPORT_SET_METHODS
)
freeze_model = sidebar.checkbox("Freeze Model", value=True)
pretrained_model = sidebar.checkbox("Pretrained Model", value=True)
# Load Model
@st.cache_resource
def load_model(
    name_model: str, support_set_method: str, freeze_model: bool, pretrained_model: bool
):
    """Build the Prototypical Networks classifier and its Grad-CAM helper.

    Wrapped in ``st.cache_resource`` so that each distinct combination of the
    sidebar settings builds the two objects once and reuses them on reruns.

    Args:
        name_model: Key of the selected model (from ``configs.NAME_MODELS``).
        support_set_method: Selected support-set method.
        freeze_model: Value of the "Freeze Model" checkbox.
        pretrained_model: Value of the "Pretrained Model" checkbox.

    Returns:
        A ``(PrototypicalNetworksModel, PrototypicalNetworksGradCAM)`` tuple.
    """
    # NOTE(review): the constructors take support_set_method LAST, unlike this
    # function's own parameter order — confirm against the model classes.
    constructor_args = (name_model, freeze_model, pretrained_model, support_set_method)
    classifier = PrototypicalNetworksModel(*constructor_args)
    grad_cam_helper = PrototypicalNetworksGradCAM(*constructor_args)
    return classifier, grad_cam_helper
# Reuses the cached instances whenever the sidebar settings are unchanged.
prototypical_networks, custom_grad_cam = load_model(
    name_model=name_model,
    support_set_method=support_set_method,
    freeze_model=freeze_model,
    pretrained_model=pretrained_model,
)
# Application Description
st.markdown("# ❓ Application Description")
# st.write renders a string as Markdown: paragraphs need a blank line between
# them, otherwise consecutive lines are merged into a single paragraph.
st.write(
    f"""
Welcome to our HxH Character Anime Classification with Prototypical Networks 🦸 app! With just a few clicks, you can classify your favorite anime characters from Hunter x Hunter using our powerful and efficient Prototypical Networks. Our user-friendly interface makes it easy for anyone to get started, whether you're a hardcore anime fan or just looking for a fun way to pass the time.

Simply upload an image or select one of our default images, and let our app do the rest! Our app will accurately identify and classify the character, and even provide you with a Grad-CAM image to show you which parts of the image contributed most to the classification.

So what are you waiting for? Try our HxH Character Anime Classification app now and see if you can correctly identify all your favorite characters!

DISCLAIMER: The output of this app is limited to the following characters: {", ".join(configs.CLASS_CHARACTERS)}
"""
)
# Image input: an uploaded file takes precedence over the default-image dropdown.
uploaded_file = st.file_uploader(
    "Upload image file", type=["jpg", "jpeg", "png", "bmp", "tiff"]
)
select_default_images = st.selectbox("Select default images", get_default_images())
st.caption("Default Images will be used if no image is uploaded.")
select_image_button = st.button("Select Image")

if select_image_button:
    image_source = uploaded_file if uploaded_file is not None else select_default_images
    if image_source is None:
        # st.selectbox yields None when get_default_images() provides no options.
        st.error("No image to select: upload an image or add a default image.")
    else:
        try:
            # Convert to RGB so the stored array always has three color channels.
            image = np.array(Image.open(image_source).convert("RGB"))
        except (OSError, ValueError) as err:
            # Corrupt/unsupported files raise PIL's UnidentifiedImageError (an
            # OSError); report it instead of crashing the page.
            st.error(f"Could not open the selected image: {err}")
        else:
            # Persist across reruns so the preview/classify section can use it.
            st.session_state["image"] = image
            # Only confirm once the image has actually been loaded.
            st.success("Image selected")
# Preview the selected image and, on request, classify it.
if st.session_state.get("image") is not None:
    image = st.session_state["image"]

    # Three columns with only the middle one used, to center the preview.
    col1, col2, col3 = st.columns(3)
    col2.write("## 📸 Preview Image")
    col2.image(image, use_column_width=True)
    predict_image_button = col2.button("Classify Image Character")
    generate_empty_space(2)

    if predict_image_button:
        with st.spinner("Classifying Image Character..."):
            result_class = prototypical_networks.predict(image)
            result_grad_cam = custom_grad_cam.get_grad_cam(image)
        inference_time = result_class["inference_time"]

        # Side by side: input image, Grad-CAM image, most salient object.
        col1, col2, col3 = st.columns(3)
        col1.write("### 🙂 Source Image")
        col1.image(image, use_column_width=True)
        col2.write("### 😎 Grad CAM Image")
        col2.image(result_grad_cam, use_column_width=True)
        col3.write("### 🤔 Most Salient Object")
        col3.image(get_most_salient_object(image), use_column_width=True)

        st.write("### 📝 Result")
        st.write(f"Predicted Character: {result_class['character'].title()}")
        # confidence is presumably a fraction in [0, 1]; shown as a percentage.
        st.write(f"Confidence Score: {result_class['confidence'] * 100:.2f}%")
        st.write(f"Inference Time: {inference_time:.2f} s")

        # Clear the selection so the next rerun no longer shows this preview.
        st.session_state["image"] = None