hafidhsoekma committed
Commit 49bceed
Parent(s): e7eede8
First commit
This view is limited to 50 files because it contains too many changes.
- .gitattributes +34 -34
- .gitignore +163 -0
- 01-🚀 Homepage.py +149 -0
- README.md +77 -13
- assets/example_images/gon/306e5d35-b301-4299-8022-0c89dc0b7690.png +0 -0
- assets/example_images/gon/3509df87-a9cd-4500-a07a-0373cbe36715.png +0 -0
- assets/example_images/gon/620c33e9-59fe-418d-8953-1444e3cfa599.png +0 -0
- assets/example_images/gon/ab07531a-ab8f-445f-8e9d-d478dd67a73b.png +0 -0
- assets/example_images/gon/df746603-8dd9-4397-92d8-bf49f8df20d2.png +0 -0
- assets/example_images/hisoka/04af934d-ffb5-4cc9-83ad-88c585678e55.png +0 -0
- assets/example_images/hisoka/13d7867a-28e0-45f0-b141-8b8624d0e1e5.png +0 -0
- assets/example_images/hisoka/41954fdc-d740-49ec-a7ba-15cac7c22c11.png +0 -0
- assets/example_images/hisoka/422e9625-c523-4532-aa5b-dd4e21b209fc.png +0 -0
- assets/example_images/hisoka/80f95e87-2f7a-4808-9d01-4383feab90e2.png +0 -0
- assets/example_images/killua/0d2a44c4-c11e-474e-ac8b-7c0e84c7f879.png +0 -0
- assets/example_images/killua/2817e633-3239-41f1-a2bf-1be874bddf5e.png +0 -0
- assets/example_images/killua/4501242f-9bda-49b6-a3c5-23f97c8353c3.png +0 -0
- assets/example_images/killua/8aca13ab-a5b2-4192-ae4b-3b73e8c663f3.png +0 -0
- assets/example_images/killua/8b7e1854-8ca7-4ef1-8887-2c64b0309712.png +0 -0
- assets/example_images/kurapika/02265b41-9833-41eb-ad60-e043753f74b9.png +0 -0
- assets/example_images/kurapika/0650e968-d61b-4c4a-98bd-7ecdd2b991de.png +0 -0
- assets/example_images/kurapika/2728dfb5-788b-4be7-ad1b-e6d23297ecf3.png +0 -0
- assets/example_images/kurapika/3613a920-3efe-49d8-a39a-227bddefa86a.png +0 -0
- assets/example_images/kurapika/405b19b0-d982-44aa-b4c8-18e3a5e373b3.png +0 -0
- assets/example_images/leorio/00beabbf-063e-42b3-85e2-ce51c586195f.png +0 -0
- assets/example_images/leorio/613e8ffb-7534-481d-b780-6d23ecd31de4.png +0 -0
- assets/example_images/leorio/af2a59f2-fcf2-4621-bb4f-6540687b390a.png +0 -0
- assets/example_images/leorio/b134831a-5ee0-40c8-9a25-1a11329741d3.png +0 -0
- assets/example_images/leorio/ccc511a0-8a98-481c-97a1-c564a874bb60.png +0 -0
- assets/example_images/others/Presiden_Sukarno.jpg +0 -0
- assets/example_images/others/Tipe-Nen-yang-ada-di-Anime-Hunter-x-Hunter.jpg +0 -0
- assets/example_images/others/d29492bbe7604505a6f1b5394f62b393.png +0 -0
- assets/example_images/others/f575c3a5f23146b59bac51267db0ddb3.png +0 -0
- assets/example_images/others/fa4548a8f57041edb7fa19f8bf302326.png +0 -0
- assets/example_images/others/fb7c8048d54f48a29ab6aaf7f8383712.png +0 -0
- assets/example_images/others/fe96e8fce17b474195f8add2632b758e.png +0 -0
- assets/images/author.jpg +0 -0
- models/anime_face_detection_model/__init__.py +1 -0
- models/anime_face_detection_model/ssd_model.py +454 -0
- models/base_model/__init__.py +4 -0
- models/base_model/grad_cam.py +126 -0
- models/base_model/image_embeddings.py +67 -0
- models/base_model/image_similarity.py +86 -0
- models/base_model/main_model.py +52 -0
- models/deep_learning/__init__.py +4 -0
- models/deep_learning/backbone_model.py +109 -0
- models/deep_learning/deep_learning.py +90 -0
- models/deep_learning/grad_cam.py +59 -0
- models/deep_learning/image_embeddings.py +58 -0
- models/deep_learning/image_similarity.py +63 -0
.gitattributes
CHANGED
@@ -1,34 +1,34 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,163 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Custom gitignore
run_app.sh
01-🚀 Homepage.py
ADDED
@@ -0,0 +1,149 @@
import streamlit as st
from streamlit_extras.switch_page_button import switch_page

from utils.functional import generate_empty_space, set_page_config

# Set page config
set_page_config("Homepage", "🚀")

# First Header
st.markdown("# 😊 About Me")
st.write(
    """
👋 Hello everyone! My name is Hafidh Soekma Ardiansyah and I'm a student at Surabaya State University, majoring in Management Information Vocational Programs. 🎓

I am excited to share with you all about my final project for the semester. 📚 My project is about classifying anime characters from the popular Hunter X Hunter anime series using various machine learning algorithms. 🤖

To start the project, I collected a dataset of images featuring the characters from the series. 📷 Then, I preprocessed the data to ensure that the algorithms could efficiently process it. 💻

After the data preparation, I used various algorithms such as Deep Learning, Prototypical Networks, and many more to classify the characters. 🧠

Through this project, I hope to showcase my skills in machine learning and contribute to the community of anime fans who are interested in image classification. 🙌

Thank you for your attention, and please feel free to ask me any questions about the project! 🤗
"""
)

st.markdown("# 🕵️ About the Project")

st.markdown("### 🦸 HxH Character Anime Classification with Prototypical Networks")
st.write(
    "Classify your favorite Hunter x Hunter characters with our cutting-edge Prototypical Networks! 🦸♂️🦸♀️"
)
go_to_page_0 = st.button(
    "Go to page 0",
)
generate_empty_space(2)
if go_to_page_0:
    switch_page("hxh character anime classification with prototypical networks")

st.markdown("### 🔎 HxH Character Anime Detection with Prototypical Networks")
st.write(
    "Detect the presence of your beloved Hunter x Hunter characters using Prototypical Networks! 🔎🕵️♂️🕵️♀️"
)
go_to_page_1 = st.button(
    "Go to page 1",
)
generate_empty_space(2)
if go_to_page_1:
    switch_page("hxh character anime detection with prototypical networks")

st.markdown("### 📊 Image Similarity with Prototypical Networks")
st.write(
    "Discover how similar your Images are to one another with our Prototypical Networks! 📊🤔"
)
go_to_page_2 = st.button(
    "Go to page 2",
)
generate_empty_space(2)
if go_to_page_2:
    switch_page("image similarity with prototypical networks")

st.markdown("### 🌌 Image Embeddings with Prototypical Networks")
st.write(
    "Unleash the power of image embeddings to represent Images in a whole new way with our Prototypical Networks! 🌌🤯"
)
go_to_page_3 = st.button(
    "Go to page 3",
)
generate_empty_space(2)
if go_to_page_3:
    switch_page("image embeddings with prototypical networks")

st.markdown("### 🤖 HxH Character Anime Classification with Deep Learning")
st.write(
    "Experience the next level of character classification with our Deep Learning models trained on Hunter x Hunter anime characters! 🤖📈"
)
go_to_page_4 = st.button(
    "Go to page 4",
)
generate_empty_space(2)
if go_to_page_4:
    switch_page("hxh character anime classification with deep learning")

st.markdown("### 📷 HxH Character Anime Detection with Deep Learning")
st.write(
    "Detect your favorite Hunter x Hunter characters with our Deep Learning models! 📷🕵️♂️🕵️♀️"
)
go_to_page_5 = st.button(
    "Go to page 5",
)
generate_empty_space(2)
if go_to_page_5:
    switch_page("hxh character anime detection with deep learning")

st.markdown("### 🖼️ Image Similarity with Deep Learning")
st.write(
    "Discover the similarities and differences between your Images with our Deep Learning models! 🖼️🧐"
)
go_to_page_6 = st.button(
    "Go to page 6",
)
generate_empty_space(2)
if go_to_page_6:
    switch_page("image similarity with deep learning")

st.markdown("### 📈 Image Embeddings with Deep Learning")
st.write(
    "Explore a new dimension of Images representations with our Deep Learning-based image embeddings! 📈🔍"
)
go_to_page_7 = st.button(
    "Go to page 7",
)
generate_empty_space(2)
if go_to_page_7:
    switch_page("image embeddings with deep learning")

st.markdown("### 🎯 Zero-Shot Image Classification with CLIP")
st.write(
    "Classify Images with zero training using CLIP, a state-of-the-art language-image model! 🎯🤯"
)
go_to_page_8 = st.button(
    "Go to page 8",
)
generate_empty_space(2)
if go_to_page_8:
    switch_page("zero-shot image classification with clip")

st.markdown("### 😊 More About Me")
st.write(
    "Curious to learn more about the person behind these amazing projects? Check out my bio and get to know me better! 😊🧑💼"
)
go_to_page_9 = st.button(
    "Go to page 9",
)
generate_empty_space(2)
if go_to_page_9:
    switch_page("more about me")

st.markdown("### 📚 Glossary")
st.write(
    "Not sure what some of the terms used in this project mean? Check out our glossary to learn more! 📚🤓"
)
go_to_page_10 = st.button(
    "Go to page 10",
)
generate_empty_space(2)
if go_to_page_10:
    switch_page("glossary")
README.md
CHANGED
@@ -1,13 +1,77 @@
----
-title: Hunter X Hunter Anime Classification
-emoji: 🔥
-colorFrom:
-colorTo: green
-sdk: streamlit
-sdk_version: 1.19.0
-app_file:
-pinned: false
-license: mit
-
-
-
+---
+title: Hunter X Hunter Anime Classification
+emoji: 🔥
+colorFrom: white
+colorTo: green
+sdk: streamlit
+sdk_version: 1.19.0
+app_file: "01-🚀 Homepage.py"
+pinned: false
+license: mit
+python_version: 3.9.13
+---
+
+# Hunter X Hunter Anime Classification
+
+Welcome to the Hunter X Hunter Anime Classification application! This project focuses on classifying anime characters from the popular Hunter X Hunter anime series using various machine learning algorithms.
+
+## About Me
+
+👋 Hello everyone! My name is Hafidh Soekma Ardiansyah, and I'm a student at Surabaya State University, majoring in Management Information Vocational Programs.
+
+I am excited to share with you all about my final project for the semester. My project is about classifying anime characters from the Hunter X Hunter series using various machine learning algorithms. To accomplish this, I collected a dataset of images featuring the characters from the series and preprocessed the data to ensure efficient processing by the algorithms.
+
+## About the Project
+
+### HxH Character Anime Classification with Prototypical Networks
+
+Classify your favorite Hunter x Hunter characters with our cutting-edge Prototypical Networks! 🦸♂️🦸♀️
+
+### HxH Character Anime Detection with Prototypical Networks
+
+Detect the presence of your beloved Hunter x Hunter characters using Prototypical Networks! 🔎🕵️♂️🕵️♀️
+
+### Image Similarity with Prototypical Networks
+
+Discover how similar your images are to one another with our Prototypical Networks! 📊🤔
+
+### Image Embeddings with Prototypical Networks
+
+Unleash the power of image embeddings to represent images in a whole new way with our Prototypical Networks! 🌌🤯
+
+### HxH Character Anime Classification with Deep Learning
+
+Experience the next level of character classification with our Deep Learning models trained on Hunter x Hunter anime characters! 🤖📈
+
+### HxH Character Anime Detection with Deep Learning
+
+Detect your favorite Hunter x Hunter characters with our Deep Learning models! 📷🕵️♂️🕵️♀️
+
+### Image Similarity with Deep Learning
+
+Discover the similarities and differences between your images with our Deep Learning models! 🖼️🧐
+
+### Image Embeddings with Deep Learning
+
+Explore a new dimension of image representations with our Deep Learning-based image embeddings! 📈🔍
+
+### Zero-Shot Image Classification with CLIP
+
+Classify images with zero training using CLIP, a state-of-the-art language-image model! 🎯🤯
+
+### More About Me
+
+Curious to learn more about the person behind these amazing projects? Check out my bio and get to know me better! 😊🧑💼
+
+### Glossary
+
+Not sure what some of the terms used in this project mean? Check out our glossary to learn more! 📚🤓
+
+## How to Run the Application
+
+1. Clone the repository: `git clone https://huggingface.co/spaces/hafidhsoekma/Hunter-X-Hunter-Anime-Classification`
+2. Install the required dependencies: `pip install -r requirements.txt`
+3. Run the application: `streamlit run "01-🚀 Homepage.py"`
+4. Open your web browser and navigate to the provided URL to access the application.
+
+Feel free to reach out to me if you have any questions or feedback. Enjoy exploring the Hunter X Hunter Anime Classification application!
assets/example_images/gon/306e5d35-b301-4299-8022-0c89dc0b7690.png
ADDED
assets/example_images/gon/3509df87-a9cd-4500-a07a-0373cbe36715.png
ADDED
assets/example_images/gon/620c33e9-59fe-418d-8953-1444e3cfa599.png
ADDED
assets/example_images/gon/ab07531a-ab8f-445f-8e9d-d478dd67a73b.png
ADDED
assets/example_images/gon/df746603-8dd9-4397-92d8-bf49f8df20d2.png
ADDED
assets/example_images/hisoka/04af934d-ffb5-4cc9-83ad-88c585678e55.png
ADDED
assets/example_images/hisoka/13d7867a-28e0-45f0-b141-8b8624d0e1e5.png
ADDED
assets/example_images/hisoka/41954fdc-d740-49ec-a7ba-15cac7c22c11.png
ADDED
assets/example_images/hisoka/422e9625-c523-4532-aa5b-dd4e21b209fc.png
ADDED
assets/example_images/hisoka/80f95e87-2f7a-4808-9d01-4383feab90e2.png
ADDED
assets/example_images/killua/0d2a44c4-c11e-474e-ac8b-7c0e84c7f879.png
ADDED
assets/example_images/killua/2817e633-3239-41f1-a2bf-1be874bddf5e.png
ADDED
assets/example_images/killua/4501242f-9bda-49b6-a3c5-23f97c8353c3.png
ADDED
assets/example_images/killua/8aca13ab-a5b2-4192-ae4b-3b73e8c663f3.png
ADDED
assets/example_images/killua/8b7e1854-8ca7-4ef1-8887-2c64b0309712.png
ADDED
assets/example_images/kurapika/02265b41-9833-41eb-ad60-e043753f74b9.png
ADDED
assets/example_images/kurapika/0650e968-d61b-4c4a-98bd-7ecdd2b991de.png
ADDED
assets/example_images/kurapika/2728dfb5-788b-4be7-ad1b-e6d23297ecf3.png
ADDED
assets/example_images/kurapika/3613a920-3efe-49d8-a39a-227bddefa86a.png
ADDED
assets/example_images/kurapika/405b19b0-d982-44aa-b4c8-18e3a5e373b3.png
ADDED
assets/example_images/leorio/00beabbf-063e-42b3-85e2-ce51c586195f.png
ADDED
assets/example_images/leorio/613e8ffb-7534-481d-b780-6d23ecd31de4.png
ADDED
assets/example_images/leorio/af2a59f2-fcf2-4621-bb4f-6540687b390a.png
ADDED
assets/example_images/leorio/b134831a-5ee0-40c8-9a25-1a11329741d3.png
ADDED
assets/example_images/leorio/ccc511a0-8a98-481c-97a1-c564a874bb60.png
ADDED
assets/example_images/others/Presiden_Sukarno.jpg
ADDED
assets/example_images/others/Tipe-Nen-yang-ada-di-Anime-Hunter-x-Hunter.jpg
ADDED
assets/example_images/others/d29492bbe7604505a6f1b5394f62b393.png
ADDED
assets/example_images/others/f575c3a5f23146b59bac51267db0ddb3.png
ADDED
assets/example_images/others/fa4548a8f57041edb7fa19f8bf302326.png
ADDED
assets/example_images/others/fb7c8048d54f48a29ab6aaf7f8383712.png
ADDED
assets/example_images/others/fe96e8fce17b474195f8add2632b758e.png
ADDED
assets/images/author.jpg
ADDED
models/anime_face_detection_model/__init__.py
ADDED
@@ -0,0 +1 @@
from .ssd_model import SingleShotDetectorModel
models/anime_face_detection_model/ssd_model.py
ADDED
@@ -0,0 +1,454 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import time
from itertools import product as product
from math import ceil

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)


class Inception(nn.Module):
    def __init__(self):
        super(Inception, self).__init__()
        self.branch1x1 = BasicConv2d(128, 32, kernel_size=1, padding=0)
        self.branch1x1_2 = BasicConv2d(128, 32, kernel_size=1, padding=0)
        self.branch3x3_reduce = BasicConv2d(128, 24, kernel_size=1, padding=0)
        self.branch3x3 = BasicConv2d(24, 32, kernel_size=3, padding=1)
        self.branch3x3_reduce_2 = BasicConv2d(128, 24, kernel_size=1, padding=0)
        self.branch3x3_2 = BasicConv2d(24, 32, kernel_size=3, padding=1)
        self.branch3x3_3 = BasicConv2d(32, 32, kernel_size=3, padding=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch1x1_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch1x1_2 = self.branch1x1_2(branch1x1_pool)

        branch3x3_reduce = self.branch3x3_reduce(x)
        branch3x3 = self.branch3x3(branch3x3_reduce)

        branch3x3_reduce_2 = self.branch3x3_reduce_2(x)
        branch3x3_2 = self.branch3x3_2(branch3x3_reduce_2)
        branch3x3_3 = self.branch3x3_3(branch3x3_2)

        outputs = (branch1x1, branch1x1_2, branch3x3, branch3x3_3)
        return torch.cat(outputs, 1)


class CRelu(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(CRelu, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = torch.cat((x, -x), 1)
        x = F.relu(x, inplace=True)
        return x


class FaceBoxes(nn.Module):
    def __init__(self, phase, size, num_classes):
        super(FaceBoxes, self).__init__()
        self.phase = phase
        self.num_classes = num_classes
        self.size = size

        self.conv1 = CRelu(3, 24, kernel_size=7, stride=4, padding=3)
        self.conv2 = CRelu(48, 64, kernel_size=5, stride=2, padding=2)

        self.inception1 = Inception()
        self.inception2 = Inception()
        self.inception3 = Inception()

        self.conv3_1 = BasicConv2d(128, 128, kernel_size=1, stride=1, padding=0)
        self.conv3_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)

        self.conv4_1 = BasicConv2d(256, 128, kernel_size=1, stride=1, padding=0)
        self.conv4_2 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)

        self.loc, self.conf = self.multibox(self.num_classes)

        if self.phase == "test":
            self.softmax = nn.Softmax(dim=-1)

        if self.phase == "train":
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    if m.bias is not None:
                        nn.init.xavier_normal_(m.weight.data)
                        m.bias.data.fill_(0.02)
                    else:
                        m.weight.data.normal_(0, 0.01)
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def multibox(self, num_classes):
        loc_layers = []
        conf_layers = []
        loc_layers += [nn.Conv2d(128, 21 * 4, kernel_size=3, padding=1)]
        conf_layers += [nn.Conv2d(128, 21 * num_classes, kernel_size=3, padding=1)]
        loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
        conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
        loc_layers += [nn.Conv2d(256, 1 * 4, kernel_size=3, padding=1)]
        conf_layers += [nn.Conv2d(256, 1 * num_classes, kernel_size=3, padding=1)]
        return nn.Sequential(*loc_layers), nn.Sequential(*conf_layers)

    def forward(self, x):
        detection_sources = list()
        loc = list()
        conf = list()

        x = self.conv1(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = self.conv2(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x = self.inception1(x)
        x = self.inception2(x)
        x = self.inception3(x)
        detection_sources.append(x)

        x = self.conv3_1(x)
        x = self.conv3_2(x)
        detection_sources.append(x)

        x = self.conv4_1(x)
        x = self.conv4_2(x)
        detection_sources.append(x)

        for x, l, c in zip(detection_sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())

        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)

        if self.phase == "test":
            output = (
                loc.view(loc.size(0), -1, 4),
                self.softmax(conf.view(-1, self.num_classes)),
            )
        else:
            output = (
                loc.view(loc.size(0), -1, 4),
                conf.view(conf.size(0), -1, self.num_classes),
            )

        return output


class PriorBox(object):
    def __init__(self, cfg, image_size=None, phase="train"):
        super(PriorBox, self).__init__()
        # self.aspect_ratios = cfg['aspect_ratios']
        self.min_sizes = cfg["min_sizes"]
        self.steps = cfg["steps"]
        self.clip = cfg["clip"]
        self.image_size = image_size
        self.feature_maps = [
            (ceil(self.image_size[0] / step), ceil(self.image_size[1] / step))
            for step in self.steps
        ]
        self.feature_maps = tuple(self.feature_maps)

    def forward(self):
        anchors = []
        for k, f in enumerate(self.feature_maps):
            min_sizes = self.min_sizes[k]
            for i, j in product(range(f[0]), range(f[1])):
                for min_size in min_sizes:
                    s_kx = min_size / self.image_size[1]
                    s_ky = min_size / self.image_size[0]
                    if min_size == 32:
                        dense_cx = [
                            x * self.steps[k] / self.image_size[1]
                            for x in [j + 0, j + 0.25, j + 0.5, j + 0.75]
                        ]
                        dense_cy = [
                            y * self.steps[k] / self.image_size[0]
                            for y in [i + 0, i + 0.25, i + 0.5, i + 0.75]
                        ]
                        for cy, cx in product(dense_cy, dense_cx):
                            anchors += [cx, cy, s_kx, s_ky]
                    elif min_size == 64:
                        dense_cx = [
                            x * self.steps[k] / self.image_size[1]
                            for x in [j + 0, j + 0.5]
                        ]
                        dense_cy = [
                            y * self.steps[k] / self.image_size[0]
                            for y in [i + 0, i + 0.5]
                        ]
                        for cy, cx in product(dense_cy, dense_cx):
                            anchors += [cx, cy, s_kx, s_ky]
                    else:
                        cx = (j + 0.5) * self.steps[k] / self.image_size[1]
                        cy = (i + 0.5) * self.steps[k] / self.image_size[0]
                        anchors += [cx, cy, s_kx, s_ky]
        # back to torch land
        output = torch.Tensor(anchors).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output


def mymax(a, b):
    if a >= b:
        return a
    else:
        return b


def mymin(a, b):
    if a >= b:
        return b
    else:
        return a


def cpu_nms(dets, thresh):
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]
    scores = dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    ndets = dets.shape[0]
    suppressed = np.zeros((ndets), dtype=int)
    keep = []
    for _i in range(ndets):
        i = order[_i]
        if suppressed[i] == 1:
            continue
        keep.append(i)
        ix1 = x1[i]
        iy1 = y1[i]
        ix2 = x2[i]
        iy2 = y2[i]
        iarea = areas[i]
        for _j in range(_i + 1, ndets):
            j = order[_j]
            if suppressed[j] == 1:
                continue
            xx1 = mymax(ix1, x1[j])
            yy1 = mymax(iy1, y1[j])
            xx2 = mymin(ix2, x2[j])
            yy2 = mymin(iy2, y2[j])
            w = mymax(0.0, xx2 - xx1 + 1)
            h = mymax(0.0, yy2 - yy1 + 1)
            inter = w * h
            ovr = inter / (iarea + areas[j] - inter)
            if ovr >= thresh:
                suppressed[j] = 1
    return tuple(keep)


def nms(dets, thresh, force_cpu=False):
    """Dispatch to either CPU or GPU NMS implementations."""

    if dets.shape[0] == 0:
        return ()
    if force_cpu:
        # return cpu_soft_nms(dets, thresh, method = 0)
        return cpu_nms(dets, thresh)
    return cpu_nms(dets, thresh)


# Adapted from https://github.com/Hakuyume/chainer-ssd
def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """

    boxes = torch.cat(
        (
            priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
            priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1]),
        ),
        1,
    )
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes


def check_keys(model, pretrained_state_dict):
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    # print('Missing keys:{}'.format(len(missing_keys)))
    # print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
    # print('Used keys:{}'.format(len(used_pretrained_keys)))
    assert len(used_pretrained_keys) > 0, "load NONE from pretrained checkpoint"
    return True


def remove_prefix(state_dict, prefix):
    """Old style model is stored with all names of parameters sharing common prefix 'module.'"""

    # print('remove prefix \'{}\''.format(prefix))
    def f(x):
        return x.split(prefix, 1)[-1] if x.startswith(prefix) else x

    return {f(key): value for key, value in state_dict.items()}


def load_model(model, pretrained_path, load_to_cpu):
    # print('Loading pretrained model from {}'.format(pretrained_path))
    if load_to_cpu:
        pretrained_dict = torch.load(
            pretrained_path, map_location=lambda storage, loc: storage
        )
    else:
        device = torch.cuda.current_device()
        pretrained_dict = torch.load(
            pretrained_path, map_location=lambda storage, loc: storage.cuda(device)
        )
    if "state_dict" in pretrained_dict.keys():
        pretrained_dict = remove_prefix(pretrained_dict["state_dict"], "module.")
    else:
        pretrained_dict = remove_prefix(pretrained_dict, "module.")
    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model


class SingleShotDetectorModel:
    def __init__(
        self,
        path_to_weights: str = "./weights/anime_face_detection/ssd_anime_face_detect.pth",
        confidence_threshold: float = 0.5,
        nms_threshold: float = 0.3,
        top_k: int = 5000,
        keep_top_k: int = 750,
    ):
        self.path_to_weights = path_to_weights
        self.confidence_threshold = confidence_threshold
        self.nms_threshold = nms_threshold
        self.top_k = top_k
        self.keep_top_k = keep_top_k

        self.cfg = {
            "name": "FaceBoxes",
            #'min_dim': 1024,
            #'feature_maps': [[32, 32], [16, 16], [8, 8]],
            # 'aspect_ratios': [[1], [1], [1]],
            "min_sizes": [[32, 64, 128], [256], [512]],
            "steps": [32, 64, 128],
            "variance": [0.1, 0.2],
            "clip": False,
            "loc_weight": 2.0,
            "gpu_train": True,
        }

        self.cpu = False if torch.cuda.is_available() else True
        torch.set_grad_enabled(False)
        self.net = FaceBoxes(phase="test", size=None, num_classes=2)
        self.net = load_model(self.net, path_to_weights, self.cpu)
        self.net.eval()
        self.device = torch.device("cpu" if self.cpu else "cuda")
        self.net = self.net.to(self.device)

    def detect_anime_face(self, image: np.ndarray) -> dict:
        image = np.float32(image)
        im_height, im_width, _ = image.shape
        scale = torch.Tensor(
            (image.shape[1], image.shape[0], image.shape[1], image.shape[0])
        )
        image -= (104, 117, 123)
        image = image.transpose(2, 0, 1)
        image = torch.from_numpy(image).unsqueeze(0)
        start_time = time.perf_counter()
        image = image.to(self.device)
        end_time = time.perf_counter() - start_time
        scale = scale.to(self.device)

        loc, conf = self.net(image)  # forward pass
        priorbox = PriorBox(self.cfg, image_size=(im_height, im_width))
        priors = priorbox.forward()
        priors = priors.to(self.device)
        prior_data = priors.data
        boxes = decode(loc.data.squeeze(0), prior_data, self.cfg["variance"])
        boxes = boxes * scale
        boxes = boxes.cpu().numpy()
        scores = conf.data.cpu().numpy()[:, 1]

        # ignore low scores
        inds = np.where(scores > self.confidence_threshold)[0]
        boxes = boxes[inds]
        scores = scores[inds]

        # keep top-K before NMS
        order = scores.argsort()[::-1][: self.top_k]
        boxes = boxes[order]
        scores = scores[order]

        # do NMS
        dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
        # keep = py_cpu_nms(dets, args.nms_threshold)
        keep = nms(dets, self.nms_threshold, force_cpu=self.cpu)
        dets = dets[keep, :]

        # keep top-K faster NMS
        dets = dets[: self.keep_top_k, :]

        return_data = []
        for k in range(dets.shape[0]):
            xmin = dets[k, 0]
            ymin = dets[k, 1]
            xmax = dets[k, 2]
            ymax = dets[k, 3]
            ymin += 0.2 * (ymax - ymin + 1)
            score = dets[k, 4]
            return_data.append([xmin, ymin, xmax, ymax, score])

        return {"anime_face": tuple(return_data), "inference_time": end_time}


if __name__ == "__main__":
    model = SingleShotDetectorModel()
    image = cv2.imread(
        "../../assets/example_images/others/d29492bbe7604505a6f1b5394f62b393.png"
    )
    data = model.detect_anime_face(image)
    for d in data["anime_face"]:
        cv2.rectangle(
            image, (int(d[0]), int(d[1])), (int(d[2]), int(d[3])), (0, 255, 0), 2
        )
    print(data)
    cv2.imshow("image", image)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
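For context, `detect_anime_face` returns plain Python data under the keys `anime_face` and `inference_time`, so the Streamlit pages only need to unpack that dict. The snippet below is a minimal usage sketch, not part of this commit; it assumes it is run from the repository root and that the default weight file exists at the path shown above.

import cv2

from models.anime_face_detection_model import SingleShotDetectorModel

# Hypothetical consumer: draw every detected face box on one of the bundled example images.
detector = SingleShotDetectorModel()
frame = cv2.imread("assets/example_images/gon/306e5d35-b301-4299-8022-0c89dc0b7690.png")
result = detector.detect_anime_face(frame)
for xmin, ymin, xmax, ymax, score in result["anime_face"]:
    cv2.rectangle(frame, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 255, 0), 2)
cv2.imwrite("detected.png", frame)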
models/base_model/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .grad_cam import BaseModelGradCAM
from .image_embeddings import BaseModelImageEmbeddings
from .image_similarity import BaseModelImageSimilarity
from .main_model import BaseModelMainModel
models/base_model/grad_cam.py
ADDED
@@ -0,0 +1,126 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from abc import ABC, abstractmethod

import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

from utils import configs
from utils.functional import (
    check_data_type_variable,
    get_device,
    image_augmentations,
    normalize_image_to_zero_one,
    reshape_transform,
)


class BaseModelGradCAM(ABC):
    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        self.name_model = name_model
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.support_set_method = support_set_method
        self.model = None
        self.device = get_device()

        self.check_arguments()

    def check_arguments(self):
        check_data_type_variable(self.name_model, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.support_set_method, str)

        old_name_model = self.name_model
        if self.name_model == configs.CLIP_NAME_MODEL:
            old_name_model = self.name_model
            self.name_model = "clip"
        if self.name_model not in tuple(configs.NAME_MODELS.keys()):
            raise ValueError(f"Model {self.name_model} not supported")
        if self.support_set_method not in configs.SUPPORT_SET_METHODS:
            raise ValueError(
                f"Support set method {self.support_set_method} not supported"
            )
        self.name_model = old_name_model

    @abstractmethod
    def init_model(self):
        pass

    def set_grad_cam(self):
        if self.name_model == "resnet50":
            self.target_layers = (self.model.model.layer4[-1],)
        elif self.name_model == "vgg16":
            self.target_layers = (self.model.model.features[-1],)
        elif self.name_model == "inception_v4":
            self.target_layers = (self.model.model.features[-1],)
        elif self.name_model == "efficientnet_b4":
            self.target_layers = (self.model.model.blocks[-1],)
        elif self.name_model == "mobilenetv3_large_100":
            self.target_layers = (self.model.model.blocks[-1],)
        elif self.name_model == "densenet121":
            self.target_layers = (self.model.model.features[-1],)
        elif self.name_model == "vit_base_patch16_224_dino":
            self.target_layers = (self.model.model.blocks[-1].norm1,)
        elif self.name_model == "clip":
            self.target_layers = (
                self.model.vision_model.encoder.layers[-1].layer_norm1,
            )
        else:
            self.target_layers = (self.model.model.features[-1],)

        if self.name_model in ("vit_base_patch16_224_dino", "clip"):
            self.gradcam = GradCAM(
                model=self.model,
                target_layers=self.target_layers,
                reshape_transform=reshape_transform,
                use_cuda=True if self.device.type == "cuda" else False,
            )
        else:
            self.gradcam = GradCAM(
                model=self.model,
                target_layers=self.target_layers,
                use_cuda=True if self.device.type == "cuda" else False,
            )

    def get_grad_cam(self, image: np.ndarray) -> np.ndarray:
        image = np.array(
            Image.fromarray(image).resize((configs.SIZE_IMAGES, configs.SIZE_IMAGES))
        )
        image_input = image_augmentations()(image=image)["image"]
        image_input = image_input.unsqueeze(axis=0).to(self.device)
        gradcam = self.gradcam(image_input)
        gradcam = gradcam[0, :]
        gradcam = show_cam_on_image(
            normalize_image_to_zero_one(image), gradcam, use_rgb=True
        )
        return gradcam

    def get_grad_cam_with_output_target(
        self, image: np.ndarray, index_class: int
    ) -> np.ndarray:
        image = np.array(
            Image.fromarray(image).resize((configs.SIZE_IMAGES, configs.SIZE_IMAGES))
        )
        image_input = image_augmentations()(image=image)["image"]
        image_input = image_input.unsqueeze(axis=0).to(self.device)
        targets = (ClassifierOutputTarget(index_class),)
        gradcam = self.gradcam(image_input, targets=targets)
        gradcam = gradcam[0, :]
        gradcam = show_cam_on_image(
            normalize_image_to_zero_one(image), gradcam, use_rgb=True
        )
        return gradcam
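BaseModelGradCAM leaves `init_model` abstract: a subclass is expected to assign `self.model` so that the attribute paths inspected by `set_grad_cam` exist (for example `self.model.model.layer4` when `name_model` is `resnet50`) and then call `set_grad_cam`. The sketch below is hypothetical and not part of this commit; it assumes `"resnet50"` is a valid key in `utils.configs.NAME_MODELS`, that the given `support_set_method` passes `check_arguments`, and that a timm backbone is an acceptable stand-in for the project's own wrappers.

import timm
import torch.nn as nn

from models.base_model import BaseModelGradCAM


class TimmBackboneWrapper(nn.Module):
    # Thin wrapper so set_grad_cam() can reach self.model.model.layer4[-1].
    def __init__(self, pretrained: bool):
        super().__init__()
        self.model = timm.create_model("resnet50", pretrained=pretrained)

    def forward(self, x):
        return self.model(x)


class ResNet50GradCAM(BaseModelGradCAM):
    # Hypothetical concrete subclass; constructor arguments must satisfy check_arguments().
    def init_model(self):
        self.model = TimmBackboneWrapper(self.pretrained_model).to(self.device).eval()
        self.set_grad_cam()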
models/base_model/image_embeddings.py
ADDED
@@ -0,0 +1,67 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import time
from abc import ABC, abstractmethod

import numpy as np
import torch

from utils import configs
from utils.functional import check_data_type_variable, get_device, image_augmentations


class BaseModelImageEmbeddings(ABC):
    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        self.name_model = name_model
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.support_set_method = support_set_method
        self.model = None
        self.device = get_device()

        self.check_arguments()

    def check_arguments(self):
        check_data_type_variable(self.name_model, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.support_set_method, str)

        old_name_model = self.name_model
        if self.name_model == configs.CLIP_NAME_MODEL:
            old_name_model = self.name_model
            self.name_model = "clip"
        if self.name_model not in tuple(configs.NAME_MODELS.keys()):
            raise ValueError(f"Model {self.name_model} not supported")
        if self.support_set_method not in configs.SUPPORT_SET_METHODS:
            raise ValueError(
                f"Support set method {self.support_set_method} not supported"
            )
        self.name_model = old_name_model

    @abstractmethod
    def init_model(self):
        pass

    def get_embeddings(self, image: np.ndarray) -> dict:
        image_input = image_augmentations()(image=image)["image"]
        image_input = image_input.unsqueeze(axis=0).to(self.device)
        with torch.no_grad():
            start_time = time.perf_counter()
            embeddings = self.model(image_input)
            end_time = time.perf_counter() - start_time

        embeddings = embeddings.detach().cpu().numpy()
        return {
            "embeddings": embeddings,
            "inference_time": end_time,
        }
models/base_model/image_similarity.py
ADDED
@@ -0,0 +1,86 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import time
from abc import ABC, abstractmethod

import numpy as np
import torch

from utils import configs
from utils.functional import (
    check_data_type_variable,
    euclidean_distance_normalized,
    get_device,
    image_augmentations,
)


class BaseModelImageSimilarity(ABC):
    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        self.name_model = name_model
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.support_set_method = support_set_method
        self.model = None
        self.device = get_device()

        self.check_arguments()

    def check_arguments(self):
        check_data_type_variable(self.name_model, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.support_set_method, str)

        old_name_model = self.name_model
        if self.name_model == configs.CLIP_NAME_MODEL:
            old_name_model = self.name_model
            self.name_model = "clip"
        if self.name_model not in tuple(configs.NAME_MODELS.keys()):
            raise ValueError(f"Model {self.name_model} not supported")
        if self.support_set_method not in configs.SUPPORT_SET_METHODS:
            raise ValueError(
                f"Support set method {self.support_set_method} not supported"
            )
        self.name_model = old_name_model

    @abstractmethod
    def init_model(self):
        pass

    def get_similarity(self, image1: np.ndarray, image2: np.ndarray) -> dict:
        image1_input = image_augmentations()(image=image1)["image"]
        image2_input = image_augmentations()(image=image2)["image"]

        image1_input = image1_input.unsqueeze(axis=0).to(self.device)
        image2_input = image2_input.unsqueeze(axis=0).to(self.device)

        with torch.no_grad():
            start_time = time.perf_counter()
            image1_input = self.model(image1_input)
            image2_input = self.model(image2_input)
            end_time = time.perf_counter() - start_time

        image1_input = image1_input.detach().cpu().numpy()
        image2_input = image2_input.detach().cpu().numpy()
        similarity = euclidean_distance_normalized(image1_input, image2_input)
        result_similarity = (
            "same image"
            if similarity
            > configs.NAME_MODELS[self.name_model]["image_similarity_threshold"]
            else "not same image"
        )
        return {
            "similarity": similarity,
            "result_similarity": result_similarity,
            "inference_time": end_time,
        }
models/base_model/main_model.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

from abc import ABC, abstractmethod

from utils import configs
from utils.functional import check_data_type_variable, get_device


class BaseModelMainModel(ABC):
    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        self.name_model = name_model
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.support_set_method = support_set_method
        self.device = get_device()

        self.check_arguments()

    def check_arguments(self):
        check_data_type_variable(self.name_model, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.support_set_method, str)

        old_name_model = self.name_model
        if self.name_model == configs.CLIP_NAME_MODEL:
            self.name_model = "clip"
        if self.name_model not in tuple(configs.NAME_MODELS.keys()):
            raise ValueError(f"Model {self.name_model} not supported")
        if self.support_set_method not in configs.SUPPORT_SET_METHODS:
            raise ValueError(
                f"Support set method {self.support_set_method} not supported"
            )
        self.name_model = old_name_model

    @abstractmethod
    def init_model(self):
        pass

    @abstractmethod
    def predict(self):
        pass
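`BaseModelMainModel` only fixes the constructor and argument validation shared by every predictor; `init_model` and `predict` are left abstract, and the `DeepLearningModel` added later in this commit is the real implementation. A toy subclass, just to show the contract (hypothetical: an untrained timm backbone and an assumed class count stand in for a real checkpoint):

# Illustrative sketch only -- not part of this commit.
import numpy as np
import timm
import torch

from utils.functional import image_augmentations


class RandomBackboneModel(BaseModelMainModel):
    def init_model(self):
        # Hypothetical: random weights instead of the repository's trained checkpoint.
        self.model = timm.create_model(self.name_model, pretrained=False, num_classes=6).to(self.device)
        self.model.eval()

    def predict(self, image: np.ndarray) -> dict:
        image_input = image_augmentations()(image=image)["image"].unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(image_input)
        probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
        return {"class_index": int(np.argmax(probs)), "confidence": float(probs.max())}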
models/deep_learning/__init__.py
ADDED
@@ -0,0 +1,4 @@
from .deep_learning import DeepLearningModel
from .grad_cam import DeepLearningGradCAM
from .image_embeddings import ImageEmbeddings
from .image_similarity import ImageSimilarity
models/deep_learning/backbone_model.py
ADDED
@@ -0,0 +1,109 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import timm
import torch
import torch.nn as nn
from transformers import CLIPModel as CLIPTransformersModel

from utils import configs
from utils.functional import check_data_type_variable, get_device


class CLIPModel(nn.Module):
    def __init__(
        self,
        model_clip_name: str,
        freeze_model: bool,
        pretrained_model: bool,
        num_classes: int,
    ):
        super().__init__()
        self.model_clip_name = model_clip_name
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.num_classes = num_classes
        self.device = get_device()

        self.check_arguments()
        self.init_model()

    def check_arguments(self):
        check_data_type_variable(self.model_clip_name, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.num_classes, int)

        if self.model_clip_name != configs.CLIP_NAME_MODEL:
            raise ValueError(
                f"Model clip name must be {configs.CLIP_NAME_MODEL}, but it is {self.model_clip_name}"
            )

    def init_model(self):
        clip_model = CLIPTransformersModel.from_pretrained(self.model_clip_name)
        for layer in clip_model.children():
            if hasattr(layer, "reset_parameters") and not self.pretrained_model:
                layer.reset_parameters()  # re-initialize weights when a non-pretrained model is requested
        for param in clip_model.parameters():
            param.requires_grad = self.freeze_model
        self.vision_model = clip_model.vision_model.to(self.device)
        self.visual_projection = clip_model.visual_projection.to(self.device)
        self.classifier = nn.Linear(
            512, 1 if self.num_classes in (1, 2) else self.num_classes
        ).to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.vision_model(x)
        x = self.visual_projection(x.pooler_output)
        x = self.classifier(x)
        return x


class TorchModel(nn.Module):
    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        num_classes: int,
    ):
        super().__init__()
        self.name_model = name_model
        self.freeze_model = freeze_model
        self.pretrained_model = pretrained_model
        self.num_classes = num_classes
        self.device = get_device()

        self.check_arguments()
        self.init_model()

    def check_arguments(self):
        check_data_type_variable(self.name_model, str)
        check_data_type_variable(self.freeze_model, bool)
        check_data_type_variable(self.pretrained_model, bool)
        check_data_type_variable(self.num_classes, int)

        if self.name_model not in tuple(configs.NAME_MODELS.keys()):
            raise ValueError(
                f"Name model must be in {tuple(configs.NAME_MODELS.keys())}, but it is {self.name_model}"
            )

    def init_model(self):
        self.model = timm.create_model(
            self.name_model, pretrained=self.pretrained_model, num_classes=0  # num_classes=0 -> pooled features
        ).to(self.device)
        for param in self.model.parameters():
            param.requires_grad = self.freeze_model
        self.classifier = nn.Linear(
            self.model.num_features,
            1 if self.num_classes in (1, 2) else self.num_classes,
        ).to(self.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.model(x)
        x = self.classifier(x)
        return x
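Both wrappers expose the same `forward` contract: a batch of 3-channel images in, logits out (a single unit when `num_classes` is 1 or 2, `num_classes` units otherwise); the hard-coded 512 in `CLIPModel` is the projection width of the CLIP checkpoint named in `configs.CLIP_NAME_MODEL`. A quick shape check, assuming `"resnet50"` is registered in `configs.NAME_MODELS` and 224x224 inputs:

# Illustrative sketch only -- not part of this commit.
import torch

backbone = TorchModel("resnet50", freeze_model=True, pretrained_model=False, num_classes=6)
dummy = torch.randn(1, 3, 224, 224).to(backbone.device)  # one fake RGB image
print(backbone(dummy).shape)  # torch.Size([1, 6])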
models/deep_learning/deep_learning.py
ADDED
@@ -0,0 +1,90 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import time

import numpy as np
import torch
from PIL import Image

from models.base_model import BaseModelMainModel
from utils import configs
from utils.functional import image_augmentations, active_learning_uncertainty

from .lightning_module import ImageClassificationLightningModule


class DeepLearningModel(BaseModelMainModel):
    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        super().__init__(name_model, freeze_model, pretrained_model, support_set_method)
        self.init_model()

    def init_model(self):
        self.model = ImageClassificationLightningModule.load_from_checkpoint(
            os.path.join(
                configs.WEIGHTS_PATH,
                self.name_model,
                self.support_set_method,
                "best.ckpt",
            ),
            name_model=self.name_model,
            freeze_model=self.freeze_model,
            pretrained_model=self.pretrained_model,
        )
        self.model = self.model.model
        for layer in self.model.children():
            if hasattr(layer, "reset_parameters") and not self.pretrained_model:
                layer.reset_parameters()
        for param in self.model.parameters():
            param.requires_grad = self.freeze_model
        self.model.to(self.device)
        self.model.eval()

    def predict(self, image: np.ndarray) -> dict:
        image_input = image_augmentations()(image=image)["image"]
        image_input = image_input.unsqueeze(0).to(self.device)  # add batch dimension
        with torch.no_grad():
            start_time = time.perf_counter()
            result = self.model(image_input)
            end_time = time.perf_counter() - start_time
        result = torch.softmax(result, dim=1)
        result = result.detach().cpu().numpy()
        result_index = np.argmax(result)
        confidence = result[0][result_index]
        uncertainty_score = active_learning_uncertainty(result[0])
        uncertainty_score = uncertainty_score if uncertainty_score > 0 else 0  # clamp negative scores to zero
        if (
            uncertainty_score
            > configs.NAME_MODELS[self.name_model][
                "deep_learning_out_of_distribution_threshold"
            ][self.support_set_method]
        ):
            # Highly uncertain predictions fall back to the last class (out-of-distribution).
            return {
                "character": configs.CLASS_CHARACTERS[-1],
                "confidence": confidence,
                "inference_time": end_time,
            }
        return {
            "character": configs.CLASS_CHARACTERS[result_index],
            "confidence": confidence,
            "inference_time": end_time,
        }


if __name__ == "__main__":
    model = DeepLearningModel("resnet50", True, True, "1_shot")
    image = np.array(
        Image.open(
            "../../assets/example_images/gon/306e5d35-b301-4299-8022-0c89dc0b7690.png"
        ).convert("RGB")
    )
    result = model.predict(image)
    print(result)
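`predict` rejects a classification as out-of-distribution when `active_learning_uncertainty` (defined in `utils/functional.py`, not shown in this hunk) exceeds a per-model, per-shot threshold from `configs`. The exact scoring function is not visible here; one common choice for this kind of gate is normalized Shannon entropy of the softmax output, sketched below purely for orientation — the repository's implementation may differ.

# Hypothetical sketch of an entropy-style uncertainty score -- the actual
# active_learning_uncertainty in utils/functional.py may be defined differently.
import numpy as np

def entropy_uncertainty(probs: np.ndarray) -> float:
    probs = np.clip(probs, 1e-12, 1.0)           # avoid log(0)
    entropy = -np.sum(probs * np.log(probs))     # Shannon entropy of the softmax output
    return float(entropy / np.log(len(probs)))   # normalize to [0, 1]

print(entropy_uncertainty(np.array([0.9, 0.05, 0.05])))        # low -> confident prediction
print(entropy_uncertainty(np.array([0.25, 0.25, 0.25, 0.25]))) # 1.0 -> maximally uncertain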
models/deep_learning/grad_cam.py
ADDED
@@ -0,0 +1,59 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from models.base_model import BaseModelGradCAM
from utils import configs

from .lightning_module import ImageClassificationLightningModule


class DeepLearningGradCAM(BaseModelGradCAM):
    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        super().__init__(name_model, freeze_model, pretrained_model, support_set_method)
        self.init_model()
        self.set_grad_cam()

    def init_model(self):
        self.model = ImageClassificationLightningModule.load_from_checkpoint(
            os.path.join(
                configs.WEIGHTS_PATH,
                self.name_model,
                self.support_set_method,
                "best.ckpt",
            ),
            name_model=self.name_model,
            freeze_model=self.freeze_model,
            pretrained_model=self.pretrained_model,
        )
        self.model = self.model.model
        for layer in self.model.children():
            if hasattr(layer, "reset_parameters") and not self.pretrained_model:
                layer.reset_parameters()
        for param in self.model.parameters():
            param.requires_grad = self.freeze_model
        self.model.to(self.device)
        self.model.eval()


if __name__ == "__main__":
    model = DeepLearningGradCAM("resnet50", False, True, "5_shot")
    image = np.array(
        Image.open(
            "../../assets/example_images/gon/306e5d35-b301-4299-8022-0c89dc0b7690.png"
        ).convert("RGB")
    )
    gradcam = model.get_grad_cam(image)
    plt.imshow(gradcam)
    plt.show()
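The heavy lifting (`set_grad_cam`, `get_grad_cam`) lives in `BaseModelGradCAM` in `models/base_model/grad_cam.py`, added elsewhere in this commit; this subclass only loads the fine-tuned checkpoint. For readers unfamiliar with the technique itself, a minimal from-scratch Grad-CAM over a CNN backbone looks roughly like the sketch below — it is a hypothetical illustration using a plain torchvision ResNet-50, not the repository's implementation.

# Hypothetical, minimal Grad-CAM sketch -- not the repository's BaseModelGradCAM.
import torch
import torch.nn.functional as F
from torchvision.models import resnet50

model = resnet50(weights=None).eval()
feats = {}
model.layer4.register_forward_hook(lambda m, i, o: feats.update(a=o))  # capture last conv block

x = torch.randn(1, 3, 224, 224)                      # stand-in for a preprocessed image
logits = model(x)
score = logits[0, logits[0].argmax()]                # score of the predicted class
grads = torch.autograd.grad(score, feats["a"])[0]    # d(score) / d(feature maps)

weights = grads.mean(dim=(2, 3), keepdim=True)                  # global-average-pooled gradients
cam = F.relu((weights * feats["a"]).sum(dim=1, keepdim=True))   # weighted combination of feature maps
cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)
heatmap = (cam / (cam.max() + 1e-8)).squeeze().detach().numpy() # normalized [0, 1] heatmap
print(heatmap.shape)  # (224, 224)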
models/deep_learning/image_embeddings.py
ADDED
@@ -0,0 +1,58 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import numpy as np
import torch.nn as nn
from PIL import Image

from models.base_model import BaseModelImageEmbeddings
from utils import configs

from .lightning_module import ImageClassificationLightningModule


class ImageEmbeddings(BaseModelImageEmbeddings):
    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        super().__init__(name_model, freeze_model, pretrained_model, support_set_method)
        self.init_model()

    def init_model(self):
        self.model = ImageClassificationLightningModule.load_from_checkpoint(
            os.path.join(
                configs.WEIGHTS_PATH,
                self.name_model,
                self.support_set_method,
                "best.ckpt",
            ),
            name_model=self.name_model,
            freeze_model=self.freeze_model,
            pretrained_model=self.pretrained_model,
        )
        self.model = self.model.model
        self.model.classifier = nn.Identity()  # drop the head so forward() returns embeddings
        for layer in self.model.children():
            if hasattr(layer, "reset_parameters") and not self.pretrained_model:
                layer.reset_parameters()
        for param in self.model.parameters():
            param.requires_grad = self.freeze_model
        self.model.to(self.device)
        self.model.eval()


if __name__ == "__main__":
    model = ImageEmbeddings("resnet50", True, True, "1_shot")
    image = np.array(
        Image.open(
            "../../assets/example_images/gon/306e5d35-b301-4299-8022-0c89dc0b7690.png"
        ).convert("RGB")
    )
    result = model.get_embeddings(image)
    print(result)
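With the classifier swapped for `nn.Identity()`, `get_embeddings` (inherited from `BaseModelImageEmbeddings`, outside this hunk) returns raw backbone feature vectors. In a few-shot setup such embeddings are typically compared against a support set; a hypothetical nearest-prototype rule is sketched below in plain NumPy (the function name, class names, and the 2048-dimensional ResNet-50 feature size are assumptions, not part of this commit).

# Hypothetical nearest-prototype classification on top of the embeddings.
import numpy as np

def nearest_prototype(query: np.ndarray, support: dict) -> str:
    """query: (D,) embedding; support: class name -> (n_shots, D) embeddings."""
    prototypes = {name: shots.mean(axis=0) for name, shots in support.items()}        # class centroids
    distances = {name: np.linalg.norm(query - proto) for name, proto in prototypes.items()}
    return min(distances, key=distances.get)                                          # closest centroid wins

support = {"gon": np.random.rand(5, 2048), "killua": np.random.rand(5, 2048)}
print(nearest_prototype(np.random.rand(2048), support))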
models/deep_learning/image_similarity.py
ADDED
@@ -0,0 +1,63 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))

import numpy as np
import torch.nn as nn
from PIL import Image

from models.base_model import BaseModelImageSimilarity
from utils import configs

from .lightning_module import ImageClassificationLightningModule


class ImageSimilarity(BaseModelImageSimilarity):
    def __init__(
        self,
        name_model: str,
        freeze_model: bool,
        pretrained_model: bool,
        support_set_method: str,
    ):
        super().__init__(name_model, freeze_model, pretrained_model, support_set_method)
        self.init_model()

    def init_model(self):
        self.model = ImageClassificationLightningModule.load_from_checkpoint(
            os.path.join(
                configs.WEIGHTS_PATH,
                self.name_model,
                self.support_set_method,
                "best.ckpt",
            ),
            name_model=self.name_model,
            freeze_model=self.freeze_model,
            pretrained_model=self.pretrained_model,
        )
        self.model = self.model.model
        self.model.classifier = nn.Identity()  # drop the head so forward() returns embeddings
        for layer in self.model.children():
            if hasattr(layer, "reset_parameters") and not self.pretrained_model:
                layer.reset_parameters()
        for param in self.model.parameters():
            param.requires_grad = self.freeze_model
        self.model.to(self.device)
        self.model.eval()


if __name__ == "__main__":
    model = ImageSimilarity("resnet50", True, True, "1_shot")
    image1 = np.array(
        Image.open(
            "../../assets/example_images/gon/306e5d35-b301-4299-8022-0c89dc0b7690.png"
        ).convert("RGB")
    )
    image2 = np.array(
        Image.open(
            "../../assets/example_images/gon/306e5d35-b301-4299-8022-0c89dc0b7690.png"
        ).convert("RGB")
    )
    result = model.get_similarity(image1, image2)
    print(result)
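The score returned by `get_similarity` comes from `euclidean_distance_normalized` in `utils/functional.py`, which is not part of this hunk, and the per-model `image_similarity_threshold` in `configs` is tuned against it. For orientation only, one plausible definition (L2-normalize both embeddings, then map the Euclidean distance into a [0, 1] similarity) is sketched below; the repository's actual function may be defined differently.

# Hypothetical sketch -- the real euclidean_distance_normalized may differ.
import numpy as np

def euclidean_similarity(a: np.ndarray, b: np.ndarray) -> float:
    a = a / np.linalg.norm(a)                 # L2-normalize both embeddings
    b = b / np.linalg.norm(b)
    distance = np.linalg.norm(a - b)          # 0 (identical direction) to 2 (opposite)
    return float(1.0 - distance / 2.0)        # map to a similarity in [0, 1]

emb = np.random.rand(512)
print(euclidean_similarity(emb, emb))         # 1.0 for identical embeddings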