import numpy as np
import streamlit as st
from PIL import Image

from models.prototypical_networks import (
    PrototypicalNetworksGradCAM,
    PrototypicalNetworksModel,
)
from utils import configs
from utils.functional import (
    generate_empty_space,
    get_default_images,
    get_most_salient_object,
    set_page_config,
    set_seed,
)

# Set seed
set_seed()

# Set page config
set_page_config(
    "HxH Character Anime Classification with Prototypical Networks", "🦸"
)

# Sidebar
name_model = st.sidebar.selectbox("Select Model", tuple(configs.NAME_MODELS.keys()))
support_set_method = st.sidebar.selectbox(
    "Select Support Set Method", configs.SUPPORT_SET_METHODS
)
freeze_model = st.sidebar.checkbox("Freeze Model", value=True)
pretrained_model = st.sidebar.checkbox("Pretrained Model", value=True)


# Load Model
@st.cache_resource
def load_model(
    name_model: str, support_set_method: str, freeze_model: bool, pretrained_model: bool
):
    """Build the classifier and its Grad-CAM explainer for the sidebar settings.

    Wrapped in ``st.cache_resource`` so the pair is constructed once per
    distinct argument combination and reused across Streamlit reruns.

    Args:
        name_model: Key selected from ``configs.NAME_MODELS``.
        support_set_method: Entry selected from ``configs.SUPPORT_SET_METHODS``.
        freeze_model: Value of the "Freeze Model" sidebar checkbox.
        pretrained_model: Value of the "Pretrained Model" sidebar checkbox.

    Returns:
        A ``(PrototypicalNetworksModel, PrototypicalNetworksGradCAM)`` tuple,
        both constructed from the same arguments.
    """
    prototypical_networks = PrototypicalNetworksModel(
        name_model, freeze_model, pretrained_model, support_set_method
    )
    custom_grad_cam = PrototypicalNetworksGradCAM(
        name_model, freeze_model, pretrained_model, support_set_method
    )
    return prototypical_networks, custom_grad_cam


def _load_rgb_image(source) -> np.ndarray:
    """Open *source* (an uploaded file object or a path) as an RGB NumPy array.

    The ``with`` block releases the underlying file handle as soon as the
    pixels have been converted.

    Raises:
        OSError: If PIL cannot open or decode the image (this includes
            ``PIL.UnidentifiedImageError``, which subclasses ``OSError``).
    """
    with Image.open(source) as img:
        return np.array(img.convert("RGB"))


prototypical_networks, custom_grad_cam = load_model(
    name_model, support_set_method, freeze_model, pretrained_model
)

# Application Description
supported_characters = ", ".join(configs.CLASS_CHARACTERS)
st.markdown("# ❓ Application Description")
st.write(
    f"""
Welcome to our HxH Character Anime Classification with Prototypical Networks 🦸 app! With just a few clicks, you can classify your favorite anime characters from Hunter x Hunter using our powerful and efficient Prototypical Networks. Our user-friendly interface makes it easy for anyone to get started, whether you're a hardcore anime fan or just looking for a fun way to pass the time. Simply upload an image or select one of our default images, and let our app do the rest!

Our app will accurately identify and classify the character, and even provide you with a Grad-CAM image to show you which parts of the image contributed most to the classification.

So what are you waiting for? Try our HxH Character Anime Classification app now and see if you can correctly identify all your favorite characters!

DISCLAIMER: This app can only classify the following characters: {supported_characters}
"""
)

uploaded_file = st.file_uploader(
    "Upload image file", type=["jpg", "jpeg", "png", "bmp", "tiff"]
)
select_default_images = st.selectbox("Select default images", get_default_images())
st.caption("Default Images will be used if no image is uploaded.")
select_image_button = st.button("Select Image")

if select_image_button:
    # An uploaded file takes precedence over the default-image selection.
    image_source = uploaded_file if uploaded_file is not None else select_default_images
    if image_source is None:
        # st.selectbox returns None when get_default_images() yields no options.
        st.warning("No image available: upload an image or select a default image.")
    else:
        try:
            st.session_state["image"] = _load_rgb_image(image_source)
        except OSError as err:
            # Uploaded files are untrusted; report unreadable input instead of crashing.
            st.error(f"Could not read the selected image: {err}")
        else:
            st.success("Image selected")

image = st.session_state.get("image")
if image is not None:
    # Centre the preview by placing it in the middle of three columns.
    _, preview_col, _ = st.columns(3)
    preview_col.write("## 📸 Preview Image")
    preview_col.image(image, use_column_width=True)
    predict_image_button = preview_col.button("Classify Image Character")
    generate_empty_space(2)

    if predict_image_button:
        with st.spinner("Classifying Image Character..."):
            result_class = prototypical_networks.predict(image)
            result_grad_cam = custom_grad_cam.get_grad_cam(image)
            salient_object = get_most_salient_object(image)
        inference_time = result_class["inference_time"]

        col1, col2, col3 = st.columns(3)
        col1.write("### 🙂 Source Image")
        col1.image(image, use_column_width=True)
        col2.write("### 😎 Grad CAM Image")
        col2.image(result_grad_cam, use_column_width=True)
        col3.write("### 🤔 Most Salient Object")
        col3.image(salient_object, use_column_width=True)

        st.write("### 📝 Result")
        st.write(f"Predicted Character: {result_class['character'].title()}")
        st.write(f"Confidence Score: {result_class['confidence'] * 100:.2f}%")
        st.write(f"Inference Time: {inference_time:.2f} s")

        # Clear the stored image so the preview is not shown again on the next rerun.
        st.session_state["image"] = None