Upload 9 files
Browse files- .gitignore +7 -0
- README.md +135 -3
- environment.yaml +313 -0
- get_data.ps1 +11 -0
- get_data.sh +7 -0
- get_samples.py +55 -0
- indices_60_32.pkl +3 -0
- inference.py +70 -0
- train.py +174 -0
@@ -0,0 +1,7 @@
1 |
2 |
3 |
4 |
5 |
6 |
7 |
@@ -1,3 +1,135 @@
1 |
2 |
3 |
1 |
# PokemonClassification
2 |
3 |
This repository explores training different models for a vision classification task, with a special focus on reproducibility, and attempts a local interpretation of the decisions made by a ResNet model using LIME.
4 |
5 |
## Table of Contents
6 |
7 |
- [PokemonClassification](#pokemonclassification)
8 |
- [Table of Contents](#table-of-contents)
9 |
- [Installation](#installation)
10 |
- [Dataset](#dataset)
11 |
- [Training](#training)
12 |
- [Inference](#inference)
13 |
- [Generating Data Samples](#generating-data-samples)
14 |
- [Interpretability](#interpretability)
15 |
- [Contributing](#contributing)
16 |
17 |
## Installation
18 |
19 |
1. Clone the repository:
20 |
21 |
22 |
git clone https://github.com/yourusername/PokemonClassification.git
23 |
cd PokemonClassification
24 |
25 |
26 |
2. Create a conda environment and activate it:
27 |
28 |
29 |
conda env create -f environment.yaml
30 |
conda activate cloudspace
31 |
32 |
33 |
## Dataset
34 |
35 |
To get the data, use the appropriate script based on your operating system:
36 |
37 |
- On Linux-based systems:
38 |
39 |
40 |
41 |
42 |
43 |
- On Windows:
44 |
45 |
46 |
47 |
48 |
49 |
## Training
50 |
51 |
To train a model, use the `train.py` script. Here are the parameters you can specify:
52 |
53 |
54 |
def parser_args():
55 |
parser = argparse.ArgumentParser(description="Pokemon Classification")
56 |
parser.add_argument("--data_dir", type=str, default="./pokemonclassification/PokemonData", help="Path to the data directory")
57 |
parser.add_argument("--indices_file", type=str, default="indices_60_32.pkl", help="Path to the indices file")
58 |
parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
59 |
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
60 |
parser.add_argument("--train_batch_size", type=int, default=128, help="train Batch size")
61 |
parser.add_argument("--test_batch_size", type=int, default=512, help="test Batch size")
62 |
parser.add_argument("--model", type=str, choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"], default="resnet", help="Model to be used")
63 |
parser.add_argument("--feature_extract", type=bool, default=True, help="whether to freeze the backbone or not")
64 |
parser.add_argument("--use_pretrained", type=bool, default=True, help="whether to use pretrained model or not")
65 |
parser.add_argument("--experiment_id", type=int, default=0, help="Experiment ID to log the results")
66 |
return parser.parse_args()
67 |
68 |
69 |
70 |
71 |
72 |
python train.py --model resnet --data_dir data/PokemonData --epochs 10 --train_batch_size 32 --test_batch_size 32
73 |
74 |
75 |
## Inference
76 |
77 |
To perform inference on a single image, use the `inference.py` script. Here are the parameters you can specify:
78 |
79 |
80 |
def main():
81 |
parser = argparse.ArgumentParser(description="Image Inference")
82 |
parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet")
83 |
parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth")
84 |
parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/Chansey/57ccf27cba024fac9531baa9f619ec62.jpg")
85 |
parser.add_argument("--num_classes", type=int, help="Number of classes", default=150)
86 |
parser.add_argument("--lime_interpretability", action="store_true", help="Whether to run interpretability or not")
87 |
parser.add_argument("--classify", action="store_true", help="Whether to classify the image when saving the lime filter")
88 |
args = parser.parse_args()
89 |
90 |
if args.lime_interpretability:
91 |
assert args.model_name == "resnet", "Interpretability is only supported for ResNet model for now"
92 |
93 |
94 |
95 |
96 |
97 |
python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10
98 |
99 |
100 |
## Generating Data Samples
101 |
102 |
To generate data samples, use the `get_samples.py` script. Here are the parameters you can specify:
103 |
104 |
105 |
def main():
106 |
parser = argparse.ArgumentParser(description="Generate Data Samples")
107 |
parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet")
108 |
parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth")
109 |
parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/")
110 |
parser.add_argument("--num_classes", type=int, help="Number of classes", default=150)
111 |
parser.add_argument("--label", type=str, help="Label to filter the images", default='Dragonair')
112 |
parser.add_argument("--num_correct", type=int, help="Number of correctly classified images", default=5)
113 |
parser.add_argument("--num_incorrect", type=int, help="Number of incorrectly classified images", default=5)
114 |
args = parser.parse_args()
115 |
116 |
117 |
118 |
119 |
120 |
python get_samples.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image_directory --num_classes 10 --label Pikachu --num_correct 5 --num_incorrect 5
121 |
122 |
123 |
## Interpretability
124 |
125 |
To interpret the model's predictions using LIME, use the `inference.py` script with the `--lime_interpretability` flag.
126 |
127 |
128 |
129 |
130 |
python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10 --lime_interpretability
131 |
132 |
133 |
## Contributing
134 |
135 |
Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.
@@ -0,0 +1,313 @@
1 |
name: cloudspace
2 |
3 |
- conda-forge
4 |
- defaults
5 |
6 |
- _libgcc_mutex=0.1=main
7 |
- _openmp_mutex=5.1=1_gnu
8 |
- aiohappyeyeballs=2.4.3=py310h06a4308_0
9 |
- alembic=1.13.3=py310h06a4308_0
10 |
- aniso8601=9.0.1=pyhd3eb1b0_0
11 |
- arrow-cpp=16.1.0=hc1eb8f0_0
12 |
- attrs=24.2.0=py310h06a4308_0
13 |
- bcrypt=3.2.0=py310h5eee18b_1
14 |
- blas=1.0=openblas
15 |
- blinker=1.6.2=py310h06a4308_0
16 |
- boost-cpp=1.82.0=hdb19cb5_2
17 |
- bottleneck=1.4.2=py310ha9d4c09_0
18 |
- brotli=1.0.9=h5eee18b_8
19 |
- brotli-bin=1.0.9=h5eee18b_8
20 |
- brotli-python=1.0.9=py310h6a678d5_8
21 |
- bzip2=1.0.8=h5eee18b_6
22 |
- c-ares=1.19.1=h5eee18b_0
23 |
- ca-certificates=2024.11.26=h06a4308_0
24 |
- certifi=2024.8.30=py310h06a4308_0
25 |
- cffi=1.17.1=py310h1fdaa30_0
26 |
- click=8.1.7=py310h06a4308_0
27 |
- cloudpickle=3.0.0=py310h06a4308_0
28 |
- contourpy=1.3.1=py310hdb19cb5_0
29 |
- cryptography=43.0.3=py310h7825ff9_1
30 |
- databricks-sdk=0.33.0=py310h06a4308_0
31 |
- deprecated=1.2.13=py310h06a4308_0
32 |
- docker-py=7.1.0=py310h06a4308_0
33 |
- entrypoints=0.4=py310h06a4308_0
34 |
- flask=3.0.3=py310h06a4308_0
35 |
- freetype=2.12.1=h4a9f257_0
36 |
- frozenlist=1.5.0=py310h5eee18b_0
37 |
- gflags=2.2.2=h6a678d5_1
38 |
- gitdb=4.0.7=pyhd3eb1b0_0
39 |
- gitpython=3.1.43=py310h06a4308_0
40 |
- glog=0.5.0=h6a678d5_1
41 |
- graphene=3.3=py310h06a4308_0
42 |
- graphql-core=3.2.3=py310h06a4308_1
43 |
- graphql-relay=3.2.0=py310h06a4308_0
44 |
- greenlet=3.0.1=py310h6a678d5_0
45 |
- gunicorn=22.0.0=py310h06a4308_0
46 |
- icu=73.1=h6a678d5_0
47 |
- importlib-metadata=8.5.0=py310h06a4308_0
48 |
- itsdangerous=2.2.0=py310h06a4308_0
49 |
- jinja2=3.1.4=py310h06a4308_1
50 |
- joblib=1.4.2=py310h06a4308_0
51 |
- jpeg=9e=h5eee18b_3
52 |
- krb5=1.20.1=h143b758_1
53 |
- lcms2=2.12=h3be6417_0
54 |
- ld_impl_linux-64=2.40=h12ee557_0
55 |
- lerc=3.0=h295c915_0
56 |
- libabseil=20240116.2=cxx17_h6a678d5_0
57 |
- libboost=1.82.0=h109eef0_2
58 |
- libbrotlicommon=1.0.9=h5eee18b_8
59 |
- libbrotlidec=1.0.9=h5eee18b_8
60 |
- libbrotlienc=1.0.9=h5eee18b_8
61 |
- libcurl=8.9.1=h251f7ec_0
62 |
- libdeflate=1.17=h5eee18b_1
63 |
- libedit=3.1.20230828=h5eee18b_0
64 |
- libev=4.33=h7f8727e_1
65 |
- libevent=2.1.12=hdbd6064_1
66 |
- libffi=3.4.4=h6a678d5_1
67 |
- libgcc-ng=11.2.0=h1234567_1
68 |
- libgfortran-ng=11.2.0=h00389a5_1
69 |
- libgfortran5=11.2.0=h1234567_1
70 |
- libgomp=11.2.0=h1234567_1
71 |
- libgrpc=1.62.2=h2d74bed_0
72 |
- libnghttp2=1.57.0=h2d74bed_0
73 |
- libopenblas=0.3.21=h043d6bf_0
74 |
- libpng=1.6.39=h5eee18b_0
75 |
- libprotobuf=4.25.3=he621ea3_0
76 |
- libsodium=1.0.18=h7b6447c_0
77 |
- libssh2=1.11.1=h251f7ec_0
78 |
- libstdcxx-ng=11.2.0=h1234567_1
79 |
- libthrift=0.15.0=h1795dd8_2
80 |
- libtiff=4.5.1=h6a678d5_0
81 |
- libuuid=1.41.5=h5eee18b_0
82 |
- libwebp-base=1.3.2=h5eee18b_1
83 |
- lz4-c=1.9.4=h6a678d5_1
84 |
- mako=1.2.3=py310h06a4308_0
85 |
- matplotlib-base=3.9.2=py310hbfdbfaf_1
86 |
- mlflow=2.18.0=hff52083_0
87 |
- mlflow-skinny=2.18.0=py310hff52083_0
88 |
- mlflow-ui=2.18.0=py310hff52083_0
89 |
- multidict=6.1.0=py310h5eee18b_0
90 |
- ncurses=6.4=h6a678d5_0
91 |
- numexpr=2.10.1=py310hd28fd6d_0
92 |
- numpy=1.26.4=py310heeff2f4_0
93 |
- numpy-base=1.26.4=py310h8a23956_0
94 |
- openjpeg=2.5.2=he7f1fd0_0
95 |
- openssl=3.0.15=h5eee18b_0
96 |
- opentelemetry-api=1.16.0=pyhd8ed1ab_0
97 |
- opentelemetry-sdk=1.16.0=pyhd8ed1ab_0
98 |
- opentelemetry-semantic-conventions=0.37b0=pyhd8ed1ab_0
99 |
- orc=2.0.1=h2d29ad5_0
100 |
- paramiko=3.5.0=py310h06a4308_0
101 |
- pillow=11.0.0=py310hfdbf927_0
102 |
- prometheus_client=0.21.0=py310h06a4308_0
103 |
- prometheus_flask_exporter=0.22.4=py310h06a4308_0
104 |
- propcache=0.2.0=py310h5eee18b_0
105 |
- pyarrow=16.1.0=py310h1128e8f_0
106 |
- pynacl=1.5.0=py310h5eee18b_0
107 |
- pyopenssl=24.2.1=py310h06a4308_0
108 |
- pyparsing=3.2.0=py310h06a4308_0
109 |
- pysocks=1.7.1=py310h06a4308_0
110 |
- python=3.10.15=he870216_1
111 |
- python-dateutil=2.9.0post0=py310h06a4308_2
112 |
- python-tzdata=2023.3=pyhd3eb1b0_0
113 |
- python_abi=3.10=2_cp310
114 |
- pyyaml=6.0.2=py310h5eee18b_0
115 |
- querystring_parser=1.2.4=py310h06a4308_0
116 |
- re2=2022.04.01=h295c915_0
117 |
- readline=8.2=h5eee18b_0
118 |
- requests=2.32.3=py310h06a4308_1
119 |
- s2n=1.3.27=hdbd6064_0
120 |
- setuptools=75.1.0=py310h06a4308_0
121 |
- six=1.16.0=pyhd3eb1b0_1
122 |
- smmap=4.0.0=pyhd3eb1b0_0
123 |
- snappy=1.2.1=h6a678d5_0
124 |
- sqlalchemy=2.0.34=py310h00e1ef3_0
125 |
- sqlite=3.45.3=h5eee18b_0
126 |
- sqlparse=0.4.4=py310h06a4308_0
127 |
- threadpoolctl=3.5.0=py310h2f386ee_0
128 |
- tk=8.6.14=h39e8969_0
129 |
- typing_extensions=4.11.0=py310h06a4308_0
130 |
- unicodedata2=15.1.0=py310h5eee18b_0
131 |
- urllib3=2.2.3=py310h06a4308_0
132 |
- utf8proc=2.6.1=h5eee18b_1
133 |
- websocket-client=1.8.0=py310h06a4308_0
134 |
- wheel=0.44.0=py310h06a4308_0
135 |
- wrapt=1.14.1=py310h5eee18b_0
136 |
- xz=5.4.6=h5eee18b_1
137 |
- yaml=0.2.5=h7b6447c_0
138 |
- yarl=1.18.0=py310h5eee18b_0
139 |
- zipp=3.21.0=py310h06a4308_0
140 |
- zlib=1.2.13=h5eee18b_1
141 |
- zstd=1.5.6=hc292b87_0
142 |
- pip:
143 |
- absl-py==2.1.0
144 |
- aiohttp==3.11.7
145 |
- aiosignal==1.3.1
146 |
- annotated-types==0.7.0
147 |
- anyio==4.6.2.post1
148 |
- argon2-cffi==23.1.0
149 |
- argon2-cffi-bindings==21.2.0
150 |
- arrow==1.3.0
151 |
- asttokens==2.4.1
152 |
- async-lru==2.0.4
153 |
- async-timeout==5.0.1
154 |
- babel==2.16.0
155 |
- backoff==2.2.1
156 |
- beautifulsoup4==4.12.3
157 |
- bleach==6.2.0
158 |
- boto3==1.35.70
159 |
- botocore==1.35.70
160 |
- cachetools==5.5.0
161 |
- charset-normalizer==3.4.0
162 |
- comm==0.2.2
163 |
- cycler==0.12.1
164 |
- debugpy==1.8.9
165 |
- decorator==5.1.1
166 |
- defusedxml==0.7.1
167 |
- exceptiongroup==1.2.2
168 |
- executing==2.1.0
169 |
- fastapi==0.115.5
170 |
- fastjsonschema==2.20.0
171 |
- filelock==3.16.1
172 |
- fire==0.7.0
173 |
- fonttools==4.55.0
174 |
- fqdn==1.5.1
175 |
- fsspec==2024.10.0
176 |
- git-filter-repo==2.47.0
177 |
- google-auth==2.36.0
178 |
- google-auth-oauthlib==1.2.1
179 |
- grpcio==1.68.0
180 |
- h11==0.14.0
181 |
- httpcore==1.0.7
182 |
- httptools==0.6.4
183 |
- httpx==0.27.2
184 |
- idna==3.10
185 |
- imageio==2.36.1
186 |
- ipykernel==6.26.0
187 |
- ipython==8.17.2
188 |
- ipywidgets==8.1.1
189 |
- isoduration==20.11.0
190 |
- jedi==0.19.2
191 |
- jmespath==1.0.1
192 |
- json5==0.10.0
193 |
- jsonpointer==3.0.0
194 |
- jsonschema==4.23.0
195 |
- jsonschema-specifications==2024.10.1
196 |
- jupyter-client==8.6.3
197 |
- jupyter-core==5.7.2
198 |
- jupyter-events==0.10.0
199 |
- jupyter-lsp==2.2.5
200 |
- jupyter-server==2.14.2
201 |
- jupyter-server-terminals==0.5.3
202 |
- jupyterlab==4.2.0
203 |
- jupyterlab-pygments==0.3.0
204 |
- jupyterlab-server==2.27.3
205 |
- jupyterlab-widgets==3.0.13
206 |
- kiwisolver==1.4.7
207 |
- lazy-loader==0.4
208 |
- lightning==2.4.0
209 |
- lightning-cloud==0.5.70
210 |
- lightning-sdk==0.1.30
211 |
- lightning-utilities==0.11.9
212 |
- lime==0.2.0.1
213 |
- litdata==0.2.32
214 |
- litserve==0.2.5
215 |
- markdown==3.7
216 |
- markdown-it-py==3.0.0
217 |
- markupsafe==3.0.2
218 |
- matplotlib==3.8.2
219 |
- matplotlib-inline==0.1.7
220 |
- mdurl==0.1.2
221 |
- mistune==3.0.2
222 |
- mpmath==1.3.0
223 |
- nbclient==0.10.0
224 |
- nbconvert==7.16.4
225 |
- nbformat==5.10.4
226 |
- nest-asyncio==1.6.0
227 |
- networkx==3.4.2
228 |
- notebook-shim==0.2.4
229 |
- nvidia-cublas-cu12==12.1.3.1
230 |
- nvidia-cuda-cupti-cu12==12.1.105
231 |
- nvidia-cuda-nvrtc-cu12==12.1.105
232 |
- nvidia-cuda-runtime-cu12==12.1.105
233 |
- nvidia-cudnn-cu12==8.9.2.26
234 |
- nvidia-cufft-cu12==11.0.2.54
235 |
- nvidia-curand-cu12==10.3.2.106
236 |
- nvidia-cusolver-cu12==11.4.5.107
237 |
- nvidia-cusparse-cu12==12.1.0.106
238 |
- nvidia-nccl-cu12==2.19.3
239 |
- nvidia-nvjitlink-cu12==12.6.85
240 |
- nvidia-nvtx-cu12==12.1.105
241 |
- oauthlib==3.2.2
242 |
- overrides==7.7.0
243 |
- packaging==24.2
244 |
- pandas==2.1.4
245 |
- pandocfilters==1.5.1
246 |
- parso==0.8.4
247 |
- pexpect==4.9.0
248 |
- pip==24.3.1
249 |
- platformdirs==4.3.6
250 |
- prompt-toolkit==3.0.48
251 |
- protobuf==4.23.4
252 |
- psutil==6.1.0
253 |
- ptyprocess==0.7.0
254 |
- pure-eval==0.2.3
255 |
- pyasn1==0.6.1
256 |
- pyasn1-modules==0.4.1
257 |
- pycparser==2.22
258 |
- pydantic==2.10.2
259 |
- pydantic-core==2.27.1
260 |
- pygments==2.18.0
261 |
- pyjwt==2.10.0
262 |
- python-dotenv==1.0.1
263 |
- python-json-logger==2.0.7
264 |
- python-multipart==0.0.17
265 |
- pytorch-lightning==2.4.0
266 |
- pytz==2024.2
267 |
- pyzmq==26.2.0
268 |
- referencing==0.35.1
269 |
- requests-oauthlib==2.0.0
270 |
- rfc3339-validator==0.1.4
271 |
- rfc3986-validator==0.1.1
272 |
- rich==13.9.4
273 |
- rpds-py==0.21.0
274 |
- rsa==4.9
275 |
- s3transfer==0.10.4
276 |
- scikit-image==0.24.0
277 |
- scikit-learn==1.3.2
278 |
- scipy==1.11.4
279 |
- send2trash==1.8.3
280 |
- simple-term-menu==1.6.5
281 |
- sniffio==1.3.1
282 |
- soupsieve==2.6
283 |
- stack-data==0.6.3
284 |
- starlette==0.41.3
285 |
- sympy==1.13.3
286 |
- tensorboard==2.15.1
287 |
- tensorboard-data-server==0.7.2
288 |
- termcolor==2.5.0
289 |
- terminado==0.18.1
290 |
- tifffile==2024.9.20
291 |
- tinycss2==1.4.0
292 |
- tomli==2.1.0
293 |
- torch==2.2.1+cu121
294 |
- torchmetrics==1.3.1
295 |
- torchvision==0.17.1+cu121
296 |
- tornado==6.4.2
297 |
- tqdm==4.67.1
298 |
- traitlets==5.14.3
299 |
- triton==2.2.0
300 |
- types-python-dateutil==2.9.0.20241003
301 |
- typing-extensions==4.12.2
302 |
- tzdata==2024.2
303 |
- uri-template==1.3.0
304 |
- uvicorn==0.32.1
305 |
- uvloop==0.21.0
306 |
- watchfiles==1.0.0
307 |
- wcwidth==0.2.13
308 |
- webcolors==24.11.1
309 |
- webencodings==0.5.1
310 |
- websockets==14.1
311 |
- werkzeug==3.1.3
312 |
- widgetsnbextension==4.0.13
313 |
prefix: /home/zeus/miniconda3/envs/cloudspace
@@ -0,0 +1,11 @@
1 |
# Download the file using Invoke-WebRequest
2 |
$destination = "./pokemonclassification.zip"
3 |
$url = "https://www.kaggle.com/api/v1/datasets/download/lantian773030/pokemonclassification"
4 |
Invoke-WebRequest -Uri $url -OutFile $destination
5 |
6 |
# Extract the zip file to the specified folder
7 |
$extractPath = "./pokemonclassification"
8 |
Expand-Archive -Path $destination -DestinationPath $extractPath
9 |
10 |
# Remove the downloaded zip file (if not needed anymore)
11 |
Remove-Item $destination
@@ -0,0 +1,7 @@
1 |
2 |
curl -L -o ./pokemonclassification.zip https://www.kaggle.com/api/v1/datasets/download/lantian773030/pokemonclassification
3 |
4 |
unzip -d ./pokemonclassification pokemonclassification.zip
5 |
6 |
# Remove the downloaded zip file (if you don't need it anymore)
7 |
rm ./pokemonclassification.zip
@@ -0,0 +1,55 @@
1 |
from utils.inference_utils import find_images_from_path
2 |
import torch
3 |
import argparse
4 |
from utils.train_utils import initialize_model
5 |
6 |
7 |
def main():
8 |
parser = argparse.ArgumentParser(description="Image Inference")
9 |
10 |
11 |
12 |
help="Model name (resnet, alexnet, vgg, squeezenet, densenet)",
13 |
14 |
15 |
16 |
17 |
18 |
help="Path to the model weights",
19 |
20 |
21 |
22 |
23 |
24 |
help="Path to the image",
25 |
26 |
27 |
28 |
"--num_classes", type=int, help="Number of classes", default=150
29 |
30 |
31 |
"--label", type=str, help="Label to filter the images", default='Dragonair' # Krabby, Clefairy
32 |
33 |
34 |
"--num_correct", type=int, help="Number of correctly classified images", default=5
35 |
36 |
37 |
"--num_incorrect", type=int, help="Number of incorrectly classified images", default=5
38 |
39 |
40 |
args = parser.parse_args()
41 |
42 |
assert (args.model_name == "resnet"), "Only the ResNet is supported model for now"
43 |
44 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45 |
46 |
# Initialize the model
47 |
model = initialize_model(args.model_name, args.num_classes)
48 |
model = model.to(device)
49 |
50 |
# Load the model weights
51 |
model.load_state_dict(torch.load(args.model_weights, map_location=device))
52 |
find_images_from_path(args.image_path, model, device, args.num_correct, args.num_incorrect, args.label)
53 |
54 |
if __name__ == "__main__":
55 |
@@ -0,0 +1,3 @@
1 |
version https://git-lfs.github.com/spec/v1
2 |
oid sha256:32b56617770be9430d034b41ef9235bb938cd1db40a78bb15a7d579229c79511
3 |
size 20240
@@ -0,0 +1,70 @@
1 |
import torch
2 |
import argparse
3 |
from utils.inference_utils import preprocess_image, predict
4 |
from utils.train_utils import initialize_model
5 |
from utils.interpretability import lime_interpret_image_inference
6 |
from utils.data import CLASS_NAMES
7 |
8 |
9 |
def main():
10 |
parser = argparse.ArgumentParser(description="Image Inference")
11 |
12 |
13 |
14 |
help="Model name (resnet, alexnet, vgg, squeezenet, densenet)",
15 |
16 |
17 |
18 |
19 |
20 |
help="Path to the model weights",
21 |
22 |
23 |
24 |
25 |
26 |
help="Path to the image",
27 |
28 |
29 |
30 |
"--num_classes", type=int, help="Number of classes", default=150
31 |
32 |
33 |
34 |
35 |
help="Whether to run interpretability or not",
36 |
37 |
38 |
39 |
40 |
help="Whether to classify the image when saving the lime filter")
41 |
42 |
args = parser.parse_args()
43 |
44 |
if args.lime_interpretability:
45 |
assert (
46 |
args.model_name == "resnet"
47 |
), "Interpretability is only supported for ResNet model for now"
48 |
49 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50 |
51 |
# Initialize the model
52 |
model = initialize_model(args.model_name, args.num_classes)
53 |
model = model.to(device)
54 |
55 |
# Load the model weights
56 |
model.load_state_dict(torch.load(args.model_weights, map_location=device))
57 |
58 |
# Preprocess the image
59 |
image = preprocess_image(args.image_path, (224, 224)).to(device)
60 |
61 |
# Perform inference
62 |
preds = torch.max(predict(model, image), 1)[1]
63 |
print(f"Predicted class: {CLASS_NAMES[preds.item()]}")
64 |
65 |
if args.lime_interpretability:
66 |
lime_interpret_image_inference(args, model, image, device)
67 |
68 |
69 |
if __name__ == "__main__":
70 |
@@ -0,0 +1,174 @@
1 |
import torch.nn as nn
2 |
from torchvision import transforms
3 |
from utils.data import PokemonDataModule
4 |
from utils.train import initialize_model, train_and_evaluate
5 |
import torch
6 |
import torch.optim as optim
7 |
import mlflow
8 |
import argparse
9 |
import random
10 |
11 |
# The shape of the images that the models expect
12 |
IMG_SHAPE = (224, 224)
13 |
14 |
15 |
def parser_args():
16 |
parser = argparse.ArgumentParser(description="Pokemon Classification")
17 |
18 |
19 |
20 |
21 |
help="Path to the data directory",
22 |
23 |
24 |
25 |
26 |
27 |
help="Path to the indices file",
28 |
29 |
parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
30 |
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
31 |
32 |
"--train_batch_size", type=int, default=128, help="train Batch size"
33 |
34 |
35 |
"--test_batch_size", type=int, default=512, help="test Batch size"
36 |
37 |
38 |
39 |
40 |
choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"],
41 |
42 |
help="Model to be used",
43 |
44 |
45 |
46 |
47 |
48 |
help="whether to freeze the backbone or not",
49 |
50 |
51 |
52 |
53 |
54 |
help="whether to use pretrained model or not",
55 |
56 |
57 |
58 |
59 |
60 |
help="Experiment ID to log the results",
61 |
62 |
return parser.parse_args()
63 |
64 |
65 |
if __name__ == "__main__":
66 |
args = parser_args()
67 |
68 |
pokemon_dataset = PokemonDataModule(args.data_dir)
69 |
NUM_CLASSES = len(pokemon_dataset.class_names)
70 |
71 |
# Get class names
72 |
print(f"Number of classes: {NUM_CLASSES}")
73 |
74 |
# You can only use the precomputed means and vars if using the same indices file ('indices_60_32.pkl')
75 |
if "indices_60_32.pkl" in args.indices_file:
76 |
chanel_means = torch.tensor([0.6062, 0.5889, 0.5550])
77 |
chanel_vars = torch.tensor([0.3284, 0.3115, 0.3266])
78 |
stats = {"mean": chanel_means, "std": chanel_vars}
79 |
_ = pokemon_dataset.prepare_data(
80 |
indices_file=args.indices_file, get_stats=False
81 |
82 |
83 |
stats = pokemon_dataset.prepare_data(
84 |
indices_file=args.indices_file, get_stats=True
85 |
86 |
87 |
print(f"Train dataset size: {len(pokemon_dataset.train_dataset)}")
88 |
print(f"Test dataset size: {len(pokemon_dataset.test_dataset)}")
89 |
90 |
# Transformations of data for testing
91 |
test_transform = transforms.Compose(
92 |
93 |
94 |
transforms.ToTensor(), # Convert PIL images to tensors
95 |
transforms.Normalize(**stats), # Normalize images using mean and std
96 |
97 |
98 |
99 |
# Data augmentations for training
100 |
train_transform = transforms.Compose(
101 |
102 |
103 |
104 |
105 |
transforms.RandomCrop(IMG_SHAPE, padding=4),
106 |
transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
107 |
108 |
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
# get dataloaders
117 |
trainloader, testloader = pokemon_dataset.get_dataloaders(
118 |
119 |
120 |
121 |
122 |
123 |
124 |
pokemon_dataset.plot_examples(testloader, stats=stats)
125 |
126 |
pokemon_dataset.plot_examples(trainloader, stats=stats)
127 |
128 |
# Try finetuning a resnet, for example
129 |
model = initialize_model(
130 |
131 |
132 |
133 |
134 |
135 |
136 |
# Print the model we just instantiated
137 |
138 |
139 |
# Model, criterion, optimizer
140 |
criterion = nn.CrossEntropyLoss()
141 |
optimizer = optim.Adam(model.parameters(), lr=args.lr)
142 |
143 |
# Device configuration
144 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
145 |
146 |
with mlflow.start_run(
147 |
148 |
run_name=f"{args.model}_{'finetuning' if not args.feature_extract else 'feature_extracting'}"
149 |
f"_{'pretrained' if args.use_pretrained else 'not_pretrained'}"
150 |
f"_{args.indices_file}_{random.randint(0, 1000)}",
151 |
) as run:
152 |
mlflow.log_param("epochs", args.epochs)
153 |
mlflow.log_param("lr", args.lr)
154 |
mlflow.log_param("train_batch_size", args.train_batch_size)
155 |
mlflow.log_param("test_batch_size", args.test_batch_size)
156 |
mlflow.log_param("model", args.model)
157 |
mlflow.log_param("feature_extract", args.feature_extract)
158 |
mlflow.log_param("use_pretrained", args.use_pretrained)
159 |
160 |
# Train and evaluate
161 |
history = train_and_evaluate(
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
# Save the model
172 |
torch.save(model.state_dict(), f"pokemon_{args.model}.pth")
173 |
174 |