Hunter-X-Hunter-Anime-Classification / pages /03-🔎 HxH Character Anime Detection with Prototypical Networks.py
hafidhsoekma's picture
First commit
49bceed
import cv2
import numpy as np
import streamlit as st
from PIL import Image
from models.anime_face_detection_model import SingleShotDetectorModel
from models.prototypical_networks import (
PrototypicalNetworksGradCAM,
PrototypicalNetworksModel,
)
from utils import configs
from utils.functional import (
generate_empty_space,
get_default_images,
get_most_salient_object,
set_page_config,
set_seed,
)
# Set seed
# NOTE(review): set_seed() comes from utils.functional and is called with its
# defaults; presumably it fixes the RNG seeds for reproducible runs -- confirm.
set_seed()

# Set page config
# The second argument is the page icon. It must be the real emoji character;
# the previous value was a mis-decoded (mojibake) byte sequence.
set_page_config("HxH Character Anime Detection with Prototypical Networks", "🔎")

# Sidebar
# These four widget values are the arguments later passed to load_model().
name_model = st.sidebar.selectbox("Select Model", tuple(configs.NAME_MODELS.keys()))
support_set_method = st.sidebar.selectbox(
    "Select Support Set Method", configs.SUPPORT_SET_METHODS
)
freeze_model = st.sidebar.checkbox("Freeze Model", value=True)
pretrained_model = st.sidebar.checkbox("Pretrained Model", value=True)
# Load Model
@st.cache_resource
def load_model(
    name_model: str, support_set_method: str, freeze_model: bool, pretrained_model: bool
):
    """Instantiate the models used by this page.

    The function is wrapped in ``st.cache_resource``, so Streamlit reuses the
    returned objects across reruns for an identical argument combination.

    Args:
        name_model: Value of the "Select Model" sidebar widget.
        support_set_method: Value of the "Select Support Set Method" widget.
        freeze_model: Value of the "Freeze Model" checkbox.
        pretrained_model: Value of the "Pretrained Model" checkbox.

    Returns:
        A ``(classifier, grad_cam, face_detector)`` tuple holding a
        ``PrototypicalNetworksModel``, a ``PrototypicalNetworksGradCAM`` and a
        ``SingleShotDetectorModel``, in that order.
    """
    # Both prototypical-network wrappers receive the same positional arguments;
    # note the order differs from this function's own parameter order.
    shared_args = (name_model, freeze_model, pretrained_model, support_set_method)

    classifier = PrototypicalNetworksModel(*shared_args)
    grad_cam = PrototypicalNetworksGradCAM(*shared_args)
    face_detector = SingleShotDetectorModel()
    return classifier, grad_cam, face_detector
# Build the models for the current sidebar selection (load_model is decorated
# with st.cache_resource above, so repeated reruns reuse the same objects).
prototypical_networks, custom_grad_cam, ssd_model = load_model(
    name_model, support_set_method, freeze_model, pretrained_model
)

# Application Description
st.markdown("# ❓ Application Description")
# The emoji run closing the first sentence was mojibake; it is restored to
# the intended characters. The rest of the text is unchanged.
st.write(
    f"""
Welcome to our HxH Character Anime Detection with Prototypical Networks application! 🕵️‍♂️🦸‍♀️🔍
This powerful and efficient tool allows you to quickly and accurately identify your favorite anime characters from Hunter x Hunter using state-of-the-art Prototypical Networks. Simply upload an image or select one of our default options, and let our model do the rest! With our user-friendly interface, anyone can easily classify HxH anime characters with just a few clicks.
But that's not all! Our application also features a powerful Grad-CAM visualization tool that lets you see which parts of the image the model is using to make its predictions. Plus, with lightning-fast inference times, you won't have to wait long to get your results.
Whether you're a hardcore anime fan or just looking for a fun way to pass the time, our HxH Character Anime Detection app is sure to entertain and delight. So what are you waiting for? Give it a try and see how many characters you can identify!
DISCLAIMER: The output of this app only {", ".join(configs.CLASS_CHARACTERS)}
"""
)
# Image selection: the user may upload a file or fall back to a bundled default.
uploaded_file = st.file_uploader(
    "Upload image file", type=["jpg", "jpeg", "png", "bmp", "tiff"]
)
select_default_images = st.selectbox("Select default images", get_default_images())
st.caption("Default Images will be used if no image is uploaded.")
select_image_button = st.button("Select Image")

if select_image_button:
    st.success("Image selected")
    # An uploaded file takes precedence over the default-image choice.
    image_source = select_default_images if uploaded_file is None else uploaded_file
    image = np.array(Image.open(image_source).convert("RGB"))
    # Stored in session state so the image is still available on later reruns.
    st.session_state["image"] = image
def _clip_box(box, width, height):
    """Return ``box[0..3]`` as integer ``(x1, y1, x2, y2)`` clamped to the image.

    Clamping keeps the coordinates inside ``[0, width]`` / ``[0, height]`` so
    that a negative detector output cannot turn into a wrap-around NumPy slice.
    """
    x1 = min(max(int(box[0]), 0), width)
    y1 = min(max(int(box[1]), 0), height)
    x2 = min(max(int(box[2]), 0), width)
    y2 = min(max(int(box[3]), 0), height)
    return x1, y1, x2, y2


# Preview the selected image and, on request, detect and classify every face.
if st.session_state.get("image") is not None:
    image = st.session_state.get("image")

    col1, col2, col3 = st.columns(3)
    col2.write("## 📸 Preview Image")
    col2.image(image, use_column_width=True)
    predict_image_button = col2.button("Detect Image Character")
    generate_empty_space(2)

    if predict_image_button:
        with st.spinner("Detecting Image Character..."):
            results_face_anime_detection = ssd_model.detect_anime_face(image)
            result_grad_cam = custom_grad_cam.get_grad_cam(image)
            bounding_box_image = image.copy()
            # Running total: detector time plus each per-face classification time.
            inference_time = results_face_anime_detection["inference_time"]
            results_anime_face = []
            image_height, image_width = image.shape[:2]

            detected_faces = results_face_anime_detection["anime_face"]
            if detected_faces:
                # Each detection is indexed as [x1, y1, x2, y2, confidence].
                for result in detected_faces:
                    x1, y1, x2, y2 = _clip_box(result, image_width, image_height)
                    if x2 <= x1 or y2 <= y1:
                        # Degenerate box after clamping: there is no crop to classify.
                        continue

                    crop_image = image[y1:y2, x1:x2]
                    character = prototypical_networks.predict(crop_image)
                    character_grad_cam = custom_grad_cam.get_grad_cam(crop_image)
                    results_anime_face.append(
                        {
                            "face": crop_image,
                            "face_grad_cam": character_grad_cam,
                            "most_salient_object": get_most_salient_object(crop_image),
                            "character": character["character"],
                            "confidence_detection": result[4],
                            "confidence_classification": character["confidence"],
                        }
                    )
                    inference_time += character["inference_time"]

                    cv2.rectangle(
                        bounding_box_image,
                        (x1, y1),
                        (x2, y2),
                        (255, 255, 0),
                        4,
                    )
                    cv2.putText(
                        bounding_box_image,
                        character["character"],
                        # Keep the text baseline on the canvas for faces that
                        # touch the top edge of the image.
                        (x1, max(y1 - 10, 30)),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1,
                        (255, 255, 0),
                        2,
                    )

        # Whole-image overview.
        col1, col2, col3, col4 = st.columns(4)
        col1.write("### 🙂 Source Image")
        col1.image(image, use_column_width=True)
        col2.write("### 😉 Detected Image")
        col2.image(bounding_box_image, use_column_width=True)
        col3.write("### 😎 Grad CAM Image")
        col3.image(result_grad_cam, use_column_width=True)
        col4.write("### 🤔 Most Salient Object")
        col4.image(get_most_salient_object(image), use_column_width=True)

        st.write("### 📝 Result")
        st.write(f"Inference Time: {inference_time:.2f} s")

        # One row of images plus scores per classified face.
        for result in results_anime_face:
            col1, col2, col3 = st.columns(3)
            col1.write("#### 🙂 Cropped Face Image")
            col1.image(result["face"], use_column_width=True)
            col2.write("#### 😎 Cropped Face Grad CAM Image")
            col2.image(result["face_grad_cam"], use_column_width=True)
            col3.write("### 🤔 Most Salient Object")
            # The stored value is already the output of get_most_salient_object;
            # show it as-is instead of running the function a second time.
            col3.image(result["most_salient_object"], use_column_width=True)

            st.write(f"Character: {result['character'].title()}")
            st.write(
                f"Confidence Score Detection: {result['confidence_detection']*100:.2f}%"
            )
            st.write(
                f"Confidence Score Classification: {result['confidence_classification']*100:.2f}%"
            )
            generate_empty_space(2)

        # Drop the stored image once results are rendered, so the preview does
        # not reappear on the next rerun until a new image is selected.
        st.session_state["image"] = None