"""Streamlit app that identifies Hunter x Hunter characters in an image.

On "Detect Image Character":
1. ``SingleShotDetectorModel.detect_anime_face`` finds anime faces.
2. Every face crop is classified by ``PrototypicalNetworksModel.predict``.
3. Grad-CAM (``PrototypicalNetworksGradCAM.get_grad_cam``) and
   ``get_most_salient_object`` views are rendered for the whole image and
   for each face crop.
"""

from typing import Tuple

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from models.anime_face_detection_model import SingleShotDetectorModel
from models.prototypical_networks import (
    PrototypicalNetworksGradCAM,
    PrototypicalNetworksModel,
)
from utils import configs
from utils.functional import (
    generate_empty_space,
    get_default_images,
    get_most_salient_object,
    set_page_config,
    set_seed,
)

# Set seed (called before any model is built).
set_seed()

# Set page config
set_page_config("HxH Character Anime Detection with Prototypical Networks", "🔎")

# Sidebar: model configuration. Each distinct combination is cached
# separately by ``load_model`` below.
name_model = st.sidebar.selectbox("Select Model", tuple(configs.NAME_MODELS.keys()))
support_set_method = st.sidebar.selectbox(
    "Select Support Set Method", configs.SUPPORT_SET_METHODS
)
freeze_model = st.sidebar.checkbox("Freeze Model", value=True)
pretrained_model = st.sidebar.checkbox("Pretrained Model", value=True)


# Load Model
@st.cache_resource
def load_model(
    name_model: str, support_set_method: str, freeze_model: bool, pretrained_model: bool
):
    """Build the classifier, its Grad-CAM wrapper and the face detector.

    ``st.cache_resource`` makes Streamlit construct these objects once per
    distinct argument combination and reuse them across script reruns.

    Args:
        name_model: Selected model name (a key of ``configs.NAME_MODELS``).
        support_set_method: Selected entry of ``configs.SUPPORT_SET_METHODS``.
        freeze_model: Forwarded to both prototypical-network wrappers.
        pretrained_model: Forwarded to both prototypical-network wrappers.

    Returns:
        Tuple ``(prototypical_networks, custom_grad_cam, ssd_model)``.
    """
    prototypical_networks = PrototypicalNetworksModel(
        name_model, freeze_model, pretrained_model, support_set_method
    )
    custom_grad_cam = PrototypicalNetworksGradCAM(
        name_model, freeze_model, pretrained_model, support_set_method
    )
    ssd_model = SingleShotDetectorModel()
    return prototypical_networks, custom_grad_cam, ssd_model


def _load_rgb_image(source) -> np.ndarray:
    """Open ``source`` with PIL and return its pixels as an RGB ``np.ndarray``.

    ``source`` is anything ``PIL.Image.open`` accepts: a path, or a binary
    file-like object such as Streamlit's ``UploadedFile``. The image is opened
    in a ``with`` block so a file PIL opened itself is closed immediately
    rather than when the object is garbage-collected.
    """
    with Image.open(source) as pil_image:
        return np.array(pil_image.convert("RGB"))


def _clip_box(box, height: int, width: int) -> Tuple[int, int, int, int]:
    """Return ``(x1, y1, x2, y2)`` from ``box`` clamped to the image bounds.

    ``box[0:4]`` is read as ``x1, y1, x2, y2``, matching how detections are
    used below. Clamping prevents a slightly negative coordinate from being
    interpreted by NumPy slicing as an index from the end of the axis, which
    would silently produce a wrong or empty crop.
    """
    x1 = min(max(int(box[0]), 0), width)
    y1 = min(max(int(box[1]), 0), height)
    x2 = min(max(int(box[2]), 0), width)
    y2 = min(max(int(box[3]), 0), height)
    return x1, y1, x2, y2


prototypical_networks, custom_grad_cam, ssd_model = load_model(
    name_model, support_set_method, freeze_model, pretrained_model
)

# Application Description
st.markdown("# ❓ Application Description")
st.write(
    f"""
Welcome to our HxH Character Anime Detection with Prototypical Networks application! 🕵️‍♂️🦸‍♀️🔍

This powerful and efficient tool allows you to quickly and accurately identify your favorite anime characters from Hunter x Hunter using state-of-the-art Prototypical Networks. Simply upload an image or select one of our default options, and let our model do the rest!

With our user-friendly interface, anyone can easily classify HxH anime characters with just a few clicks. But that's not all! Our application also features a powerful Grad-CAM visualization tool that lets you see which parts of the image the model is using to make its predictions. Plus, with lightning-fast inference times, you won't have to wait long to get your results.

Whether you're a hardcore anime fan or just looking for a fun way to pass the time, our HxH Character Anime Detection app is sure to entertain and delight. So what are you waiting for? Give it a try and see how many characters you can identify!

DISCLAIMER: The output of this app is limited to the following characters: {", ".join(configs.CLASS_CHARACTERS)}
"""
)

# Image selection
uploaded_file = st.file_uploader(
    "Upload image file", type=["jpg", "jpeg", "png", "bmp", "tiff"]
)
select_default_images = st.selectbox("Select default images", get_default_images())
st.caption("Default Images will be used if no image is uploaded.")
select_image_button = st.button("Select Image")

if select_image_button:
    st.success("Image selected")
    # An uploaded file takes precedence over the default-image selectbox.
    image_source = uploaded_file if uploaded_file is not None else select_default_images
    # Stored in session state because clicking "Detect Image Character"
    # starts a new script run in which select_image_button is False.
    st.session_state["image"] = _load_rgb_image(image_source)

if st.session_state.get("image") is not None:
    image = st.session_state.get("image")

    # Three columns with only the middle one used, to centre the preview.
    _, preview_col, _ = st.columns(3)
    preview_col.write("## 📸 Preview Image")
    preview_col.image(image, use_column_width=True)
    predict_image_button = preview_col.button("Detect Image Character")
    generate_empty_space(2)

    if predict_image_button:
        with st.spinner("Detecting Image Character..."):
            results_face_anime_detection = ssd_model.detect_anime_face(image)
            result_grad_cam = custom_grad_cam.get_grad_cam(image)
            # Annotations are drawn on a copy so the source panel stays clean.
            bounding_box_image = image.copy()
            inference_time = results_face_anime_detection["inference_time"]
            image_height, image_width = image.shape[:2]

            results_anime_face = []
            if results_face_anime_detection["anime_face"]:
                # Each detection is indexed as [x1, y1, x2, y2, score].
                for result in results_face_anime_detection["anime_face"]:
                    x1, y1, x2, y2 = _clip_box(result, image_height, image_width)
                    crop_image = image[y1:y2, x1:x2]
                    if crop_image.size == 0:
                        # Degenerate or fully out-of-frame box: nothing to classify.
                        continue

                    character = prototypical_networks.predict(crop_image)
                    character_grad_cam = custom_grad_cam.get_grad_cam(crop_image)
                    results_anime_face.append(
                        {
                            "face": crop_image,
                            "face_grad_cam": character_grad_cam,
                            "most_salient_object": get_most_salient_object(crop_image),
                            "character": character["character"],
                            "confidence_detection": result[4],
                            "confidence_classification": character["confidence"],
                        }
                    )
                    # Reported time = detection time + every per-face classification.
                    inference_time += character["inference_time"]

                    cv2.rectangle(
                        bounding_box_image,
                        (x1, y1),
                        (x2, y2),
                        (255, 255, 0),
                        4,
                    )
                    cv2.putText(
                        bounding_box_image,
                        character["character"],
                        (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1,
                        (255, 255, 0),
                        2,
                    )

        source_col, detected_col, grad_cam_col, salient_col = st.columns(4)
        source_col.write("### 🙂 Source Image")
        source_col.image(image, use_column_width=True)
        detected_col.write("### 😉 Detected Image")
        detected_col.image(bounding_box_image, use_column_width=True)
        grad_cam_col.write("### 😎 Grad CAM Image")
        grad_cam_col.image(result_grad_cam, use_column_width=True)
        salient_col.write("### 🤔 Most Salient Object")
        salient_col.image(get_most_salient_object(image), use_column_width=True)

        st.write("### 📝 Result")
        st.write(f"Inference Time: {inference_time:.2f} s")
        for result in results_anime_face:
            face_col, face_grad_cam_col, face_salient_col = st.columns(3)
            face_col.write("#### 🙂 Cropped Face Image")
            face_col.image(result["face"], use_column_width=True)
            face_grad_cam_col.write("#### 😎 Cropped Face Grad CAM Image")
            face_grad_cam_col.image(result["face_grad_cam"], use_column_width=True)
            face_salient_col.write("#### 🤔 Most Salient Object")
            # Already computed from the crop during detection; applying
            # get_most_salient_object again would transform it twice.
            face_salient_col.image(
                result["most_salient_object"],
                use_column_width=True,
            )
            st.write(f"Character: {result['character'].title()}")
            st.write(
                f"Confidence Score Detection: {result['confidence_detection']*100:.2f}%"
            )
            st.write(
                f"Confidence Score Classification: {result['confidence_classification']*100:.2f}%"
            )
            generate_empty_space(2)

        # Reset after a detection so the next interaction starts from a fresh
        # image selection.
        st.session_state["image"] = None