# Source metadata (extraction artifact): file size 7,627 bytes, commit 49bceed, 168 lines.
import cv2
import numpy as np
import streamlit as st
from PIL import Image
from models.anime_face_detection_model import SingleShotDetectorModel
from models.prototypical_networks import (
PrototypicalNetworksGradCAM,
PrototypicalNetworksModel,
)
from utils import configs
from utils.functional import (
generate_empty_space,
get_default_images,
get_most_salient_object,
set_page_config,
set_seed,
)
# Seed all RNGs so model initialization and predictions are reproducible across reruns.
set_seed()
# Configure the Streamlit page title and icon.
# NOTE(review): the icon string "π" looks mojibake-encoded (likely an emoji) — confirm intended glyph.
set_page_config("HxH Character Anime Detection with Prototypical Networks", "π")
# Sidebar controls: these four values parameterize load_model() below.
# Backbone architecture, keyed by configs.NAME_MODELS.
name_model = st.sidebar.selectbox("Select Model", tuple(configs.NAME_MODELS.keys()))
# Strategy for building the few-shot support set.
support_set_method = st.sidebar.selectbox(
    "Select Support Set Method", configs.SUPPORT_SET_METHODS
)
# Whether to freeze backbone weights / start from pretrained weights.
freeze_model = st.sidebar.checkbox("Freeze Model", value=True)
pretrained_model = st.sidebar.checkbox("Pretrained Model", value=True)
# Load Model
@st.cache_resource
def load_model(
    name_model: str, support_set_method: str, freeze_model: bool, pretrained_model: bool
):
    """Build and cache the three inference models for the app.

    Cached by Streamlit per unique argument combination, so switching
    sidebar options only constructs each configuration once.

    Returns:
        A tuple of (classifier, grad-CAM helper, anime-face detector).
    """
    proto_classifier = PrototypicalNetworksModel(
        name_model, freeze_model, pretrained_model, support_set_method
    )
    grad_cam_helper = PrototypicalNetworksGradCAM(
        name_model, freeze_model, pretrained_model, support_set_method
    )
    face_detector = SingleShotDetectorModel()
    return proto_classifier, grad_cam_helper, face_detector
# Instantiate (or fetch from cache) the classifier, Grad-CAM helper, and face detector
# using the current sidebar selections.
prototypical_networks, custom_grad_cam, ssd_model = load_model(
    name_model, support_set_method, freeze_model, pretrained_model
)
# Application Description
# NOTE(review): the emoji below appear mojibake-encoded (UTF-8 read with the wrong
# codec); left byte-identical because the originals cannot be reconstructed safely.
st.markdown("# β Application Description")
st.write(
    f"""
Welcome to our HxH Character Anime Detection with Prototypical Networks application! π΅οΈββοΈπ¦ΈββοΈπ
This powerful and efficient tool allows you to quickly and accurately identify your favorite anime characters from Hunter x Hunter using state-of-the-art Prototypical Networks. Simply upload an image or select one of our default options, and let our model do the rest! With our user-friendly interface, anyone can easily classify HxH anime characters with just a few clicks.
But that's not all! Our application also features a powerful Grad-CAM visualization tool that lets you see which parts of the image the model is using to make its predictions. Plus, with lightning-fast inference times, you won't have to wait long to get your results.
Whether you're a hardcore anime fan or just looking for a fun way to pass the time, our HxH Character Anime Detection app is sure to entertain and delight. So what are you waiting for? Give it a try and see how many characters you can identify!
DISCLAIMER: The output of this app only {", ".join(configs.CLASS_CHARACTERS)}
"""
)
# Image input: either a user upload or one of the bundled default images.
uploaded_file = st.file_uploader(
    "Upload image file", type=["jpg", "jpeg", "png", "bmp", "tiff"]
)
select_default_images = st.selectbox("Select default images", get_default_images())
st.caption("Default Images will be used if no image is uploaded.")
select_image_button = st.button("Select Image")
if select_image_button:
    st.success("Image selected")
    # Prefer the uploaded file; otherwise fall back to the selected default image.
    source = uploaded_file if uploaded_file is not None else select_default_images
    # Store as an RGB numpy array so it survives Streamlit reruns via session state.
    st.session_state["image"] = np.array(Image.open(source).convert("RGB"))
# Main detection/rendering flow: runs only once an image has been selected.
if st.session_state.get("image") is not None:
    image = st.session_state.get("image")
    col1, col2, col3 = st.columns(3)
    col2.write("## πΈ Preview Image")
    col2.image(image, use_column_width=True)
    predict_image_button = col2.button("Detect Image Character")
    generate_empty_space(2)
    if predict_image_button:
        with st.spinner("Detecting Image Character..."):
            # Detect anime faces and compute a whole-image Grad-CAM overlay.
            results_face_anime_detection = ssd_model.detect_anime_face(image)
            result_grad_cam = custom_grad_cam.get_grad_cam(image)
            bounding_box_image = image.copy()
            # Total inference time starts with detection; classification times are added per face.
            inference_time = results_face_anime_detection["inference_time"]
            results_anime_face = []
            if results_face_anime_detection["anime_face"]:
                for result in results_face_anime_detection["anime_face"]:
                    # result appears to be (x1, y1, x2, y2, score) — TODO confirm against
                    # SingleShotDetectorModel. Crop the detected face region (rows = y, cols = x).
                    crop_image = image[
                        int(result[1]) : int(result[3]), int(result[0]) : int(result[2])
                    ]
                    # Classify the cropped face and compute its Grad-CAM visualization.
                    character = prototypical_networks.predict(crop_image)
                    character_grad_cam = custom_grad_cam.get_grad_cam(
                        crop_image,
                    )
                    results_anime_face.append(
                        {
                            "face": crop_image,
                            "face_grad_cam": character_grad_cam,
                            # Saliency is computed here, once; consumers must not reapply it.
                            "most_salient_object": get_most_salient_object(crop_image),
                            "character": character["character"],
                            "confidence_detection": result[4],
                            "confidence_classification": character["confidence"],
                        }
                    )
                    inference_time += character["inference_time"]
                    # Draw the detection box and predicted character name on the copy.
                    cv2.rectangle(
                        bounding_box_image,
                        (int(result[0]), int(result[1])),
                        (int(result[2]), int(result[3])),
                        (255, 255, 0),
                        4,
                    )
                    cv2.putText(
                        bounding_box_image,
                        character["character"],
                        (int(result[0]), int(result[1]) - 10),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1,
                        (255, 255, 0),
                        2,
                    )
            # Whole-image views: source, detections, Grad-CAM, and saliency map.
            col1, col2, col3, col4 = st.columns(4)
            col1.write("### π Source Image")
            col1.image(image, use_column_width=True)
            col2.write("### π Detected Image")
            col2.image(bounding_box_image, use_column_width=True)
            col3.write("### π Grad CAM Image")
            col3.image(result_grad_cam, use_column_width=True)
            col4.write("### π€ Most Salient Object")
            col4.image(get_most_salient_object(image), use_column_width=True)
            st.write("### π Result")
            st.write(f"Inference Time: {inference_time:.2f} s")
            # Per-face results: crop, Grad-CAM, saliency, and scores.
            for result in results_anime_face:
                col1, col2, col3 = st.columns(3)
                col1.write("#### π Cropped Face Image")
                col1.image(result["face"], use_column_width=True)
                col2.write("#### π Cropped Face Grad CAM Image")
                col2.image(result["face_grad_cam"], use_column_width=True)
                col3.write("### π€ Most Salient Object")
                # FIX: "most_salient_object" already holds get_most_salient_object(crop_image)
                # (set when the dict was built above); the old code applied the saliency
                # transform a second time, double-processing the image. Display it directly,
                # matching the single application used for the whole image above.
                col3.image(
                    result["most_salient_object"],
                    use_column_width=True,
                )
                st.write(f"Character: {result['character'].title()}")
                st.write(
                    f"Confidence Score Detection: {result['confidence_detection']*100:.2f}%"
                )
                st.write(
                    f"Confidence Score Classification: {result['confidence_classification']*100:.2f}%"
                )
                generate_empty_space(2)
    # Clear the stored image at the end of the render so the next interaction
    # starts fresh. NOTE(review): this also discards the selection on any rerun
    # (e.g. sidebar change) — presumably intentional one-shot behavior; confirm.
    st.session_state["image"] = None
|