# Hunter-X-Hunter-Anime-Classification
# pages/07-📷 HxH Character Anime Detection with Deep Learning.py
import cv2
import numpy as np
import streamlit as st
from PIL import Image

from models.anime_face_detection_model import SingleShotDetectorModel
from models.deep_learning import DeepLearningGradCAM, DeepLearningModel
from utils import configs
from utils.functional import (
    generate_empty_space,
    get_default_images,
    get_most_salient_object,
    set_page_config,
    set_seed,
)
# Set seed
set_seed()

# Set page config
set_page_config("HxH Character Anime Detection with Deep Learning", "📷")

# Sidebar
name_model = st.sidebar.selectbox("Select Model", tuple(configs.NAME_MODELS.keys()))
support_set_method = st.sidebar.selectbox(
    "Select Support Set Method", configs.SUPPORT_SET_METHODS
)
freeze_model = st.sidebar.checkbox("Freeze Model", value=True)
pretrained_model = st.sidebar.checkbox("Pretrained Model", value=True)
# Load Model
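# NOTE (assumption): st.cache_resource (Streamlit >= 1.18) is added here as a
# sketch so the models are loaded once per combination of sidebar selections
# instead of being rebuilt on every rerun; remove the decorator if that
# caching behaviour is not wanted.
@st.cache_resource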
def load_model(
    name_model: str, support_set_method: str, freeze_model: bool, pretrained_model: bool
):
    deep_learning_model = DeepLearningModel(
        name_model, freeze_model, pretrained_model, support_set_method
    )
    custom_grad_cam = DeepLearningGradCAM(
        name_model, freeze_model, pretrained_model, support_set_method
    )
    ssd_model = SingleShotDetectorModel()
    return deep_learning_model, custom_grad_cam, ssd_model


deep_learning_model, custom_grad_cam, ssd_model = load_model(
    name_model, support_set_method, freeze_model, pretrained_model
)
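# deep_learning_model classifies a cropped face, custom_grad_cam produces
# Grad-CAM heatmaps for that classifier, and ssd_model localizes anime faces
# in the image before classification.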
# Application Description
st.markdown("# Application Description")
st.write(
    f"""
Introducing our HxH Character Anime Detection with Deep Learning app 📷! This app helps you easily and accurately detect and identify your favorite HxH anime characters using the power of deep learning. With just a click, you can upload an image and the model will quickly analyze it to determine which character it shows.

Our user-friendly interface makes it easy for anyone to get started, even with little to no experience in deep learning. The model has been trained on a large dataset of HxH anime characters, so you can trust its accuracy.

Whether you're a die-hard HxH fan or just looking for a fun and easy way to identify your favorite characters, this app is sure to impress. So why wait? Try out the HxH Character Anime Detection with Deep Learning app today and take your anime fandom to the next level!

DISCLAIMER: This app can only recognize the following characters: {", ".join(configs.CLASS_CHARACTERS)}.
"""
)
uploaded_file = st.file_uploader(
    "Upload image file", type=["jpg", "jpeg", "png", "bmp", "tiff"]
)
select_default_images = st.selectbox("Select default images", get_default_images())
st.caption("Default Images will be used if no image is uploaded.")
select_image_button = st.button("Select Image")

if select_image_button:
    st.success("Image selected")

if select_image_button and uploaded_file is not None:
    image = np.array(Image.open(uploaded_file).convert("RGB"))
    st.session_state["image"] = image
elif select_image_button and uploaded_file is None:
    image = np.array(Image.open(select_default_images).convert("RGB"))
    st.session_state["image"] = image
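# The selected image is stored in st.session_state so it survives the rerun
# triggered by the "Detect Image Character" button below.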
if st.session_state.get("image") is not None:
    image = st.session_state.get("image")
    col1, col2, col3 = st.columns(3)
    col2.write("## 📸 Preview Image")
    col2.image(image, use_column_width=True)
    predict_image_button = col2.button("Detect Image Character")
    generate_empty_space(2)
    if predict_image_button:
        with st.spinner("Detecting Image Character..."):
            results_face_anime_detection = ssd_model.detect_anime_face(image)
            result_grad_cam = custom_grad_cam.get_grad_cam(image)
            bounding_box_image = image.copy()
            inference_time = results_face_anime_detection["inference_time"]
            results_anime_face = []
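            # Each detection in results_face_anime_detection["anime_face"] is
            # used below as [x1, y1, x2, y2, confidence].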
            if results_face_anime_detection["anime_face"]:
                for result in results_face_anime_detection["anime_face"]:
                    # Crop the detected face region from the full image.
                    crop_image = image[
                        int(result[1]) : int(result[3]), int(result[0]) : int(result[2])
                    ]
                    character = deep_learning_model.predict(crop_image)
                    # The last entry of CLASS_CHARACTERS (presumably an "other"
                    # class) gets an untargeted Grad-CAM; known characters use
                    # their class index as the Grad-CAM target.
                    if character["character"] == configs.CLASS_CHARACTERS[-1]:
                        character_grad_cam = custom_grad_cam.get_grad_cam(crop_image)
                    else:
                        character_grad_cam = custom_grad_cam.get_grad_cam_with_output_target(
                            crop_image,
                            configs.CLASS_CHARACTERS.index(character["character"]),
                        )
                    results_anime_face.append(
                        {
                            "face": crop_image,
                            "face_grad_cam": character_grad_cam,
                            "most_salient_object": get_most_salient_object(crop_image),
                            "character": character["character"],
                            "confidence_detection": result[4],
                            "confidence_classification": character["confidence"],
                        }
                    )
                    inference_time += character["inference_time"]
                    # Draw the bounding box and predicted name on the display copy.
                    cv2.rectangle(
                        bounding_box_image,
                        (int(result[0]), int(result[1])),
                        (int(result[2]), int(result[3])),
                        (255, 255, 0),
                        4,
                    )
                    cv2.putText(
                        bounding_box_image,
                        character["character"],
                        (int(result[0]), int(result[1]) - 10),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1,
                        (255, 255, 0),
                        2,
                    )
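        # Overview row: source image, detection overlay, whole-image Grad-CAM,
        # and the most salient object.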
        col1, col2, col3, col4 = st.columns(4)
        col1.write("### Source Image")
        col1.image(image, use_column_width=True)
        col2.write("### Detected Image")
        col2.image(bounding_box_image, use_column_width=True)
        col3.write("### Grad CAM Image")
        col3.image(result_grad_cam, use_column_width=True)
        col4.write("### Most Salient Object")
        col4.image(get_most_salient_object(image), use_column_width=True)

        st.write("### Result")
        st.write(f"Inference Time: {inference_time:.2f} s")
        for result in results_anime_face:
            col1, col2, col3 = st.columns(3)
            col1.write("#### Cropped Face Image")
            col1.image(result["face"], use_column_width=True)
            col2.write("#### Cropped Face Grad CAM Image")
            col2.image(result["face_grad_cam"], use_column_width=True)
            col3.write("#### Most Salient Object")
            col3.image(result["most_salient_object"], use_column_width=True)
            st.write(f"Character: {result['character'].title()}")
            st.write(
                f"Detection Confidence Score: {result['confidence_detection'] * 100:.2f}%"
            )
            st.write(
                f"Classification Confidence Score: {result['confidence_classification'] * 100:.2f}%"
            )
            generate_empty_space(2)

        # Clear the stored image so the next detection starts from a fresh selection.
        st.session_state["image"] = None