Sifal committed on
Commit c7f0510
1 Parent(s): 4402658

Upload 9 files

Files changed (9)
  1. .gitignore +7 -0
  2. README.md +135 -3
  3. environment.yaml +313 -0
  4. get_data.ps1 +11 -0
  5. get_data.sh +7 -0
  6. get_samples.py +55 -0
  7. indices_60_32.pkl +3 -0
  8. inference.py +70 -0
  9. train.py +174 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
+ .vscode
+ pokemonclassification
+ *.zip
+ trained_models/pokemon_vgg.pth
+ *.cpython-310.pyc
+ __pycache__
+ utils/__pycache__/data.cpython-310.pyc
README.md CHANGED
@@ -1,3 +1,135 @@
- ---
- license: mit
- ---
+ # PokemonClassification
+
+ This repository explores the training of different models for a vision classification task, with a special focus on reproducibility and an attempt at local interpretability of the decisions made by a ResNet model using LIME.
+
+ ## Table of Contents
+
+ - [PokemonClassification](#pokemonclassification)
+   - [Table of Contents](#table-of-contents)
+   - [Installation](#installation)
+   - [Dataset](#dataset)
+   - [Training](#training)
+   - [Inference](#inference)
+   - [Generating Data Samples](#generating-data-samples)
+   - [Interpretability](#interpretability)
+   - [Contributing](#contributing)
+
+ ## Installation
+
+ 1. Clone the repository:
+
+ ```sh
+ git clone https://github.com/yourusername/PokemonClassification.git
+ cd PokemonClassification
+ ```
+
+ 2. Create a conda environment and activate it:
+
+ ```sh
+ conda env create -f environment.yaml
+ conda activate pokemonclassification
+ ```
+
+ ## Dataset
+
+ To get the data, use the appropriate script for your operating system:
+
+ - On Linux-based systems:
+
+ ```shell
+ ./get_data.sh
+ ```
+
+ - On Windows:
+
+ ```shell
+ ./get_data.ps1
+ ```
+
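+ Both scripts download the Kaggle archive and extract it to `./pokemonclassification/PokemonData/`, with one subfolder per Pokémon class; this is the default `--data_dir` layout that `train.py` expects.
+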
+ ## Training
+
+ To train a model, use the `train.py` script. Here are the parameters you can specify:
+
+ ```python
+ def parser_args():
+     parser = argparse.ArgumentParser(description="Pokemon Classification")
+     parser.add_argument("--data_dir", type=str, default="./pokemonclassification/PokemonData", help="Path to the data directory")
+     parser.add_argument("--indices_file", type=str, default="indices_60_32.pkl", help="Path to the indices file")
+     parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
+     parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
+     parser.add_argument("--train_batch_size", type=int, default=128, help="Train batch size")
+     parser.add_argument("--test_batch_size", type=int, default=512, help="Test batch size")
+     parser.add_argument("--model", type=str, choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"], default="resnet", help="Model to be used")
+     parser.add_argument("--feature_extract", type=bool, default=True, help="Whether to freeze the backbone or not")
+     parser.add_argument("--use_pretrained", type=bool, default=True, help="Whether to use a pretrained model or not")
+     parser.add_argument("--experiment_id", type=int, default=0, help="Experiment ID to log the results")
+     return parser.parse_args()
+ ```
+
+ Example:
+
+ ```shell
+ python train.py --model resnet --data_dir data/PokemonData --epochs 10 --train_batch_size 32 --test_batch_size 32
+ ```
+
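+ `train.py` logs each run to MLflow (see the `mlflow.log_param` calls and the `use_mlflow=True` flag passed to `train_and_evaluate`); after training, you can browse the runs locally with `mlflow ui`.
+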
+ ## Inference
+
+ To perform inference on a single image, use the `inference.py` script. Here are the parameters you can specify:
+
+ ```python
+ def main():
+     parser = argparse.ArgumentParser(description="Image Inference")
+     parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet")
+     parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth")
+     parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/Chansey/57ccf27cba024fac9531baa9f619ec62.jpg")
+     parser.add_argument("--num_classes", type=int, help="Number of classes", default=150)
+     parser.add_argument("--lime_interpretability", action="store_true", help="Whether to run interpretability or not")
+     parser.add_argument("--classify", action="store_true", help="Whether to classify the image when saving the lime filter")
+     args = parser.parse_args()
+
+     if args.lime_interpretability:
+         assert args.model_name == "resnet", "Interpretability is only supported for the ResNet model for now"
+ ```
+
+ Example:
+
+ ```shell
+ python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10
+ ```
+
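+ The `preprocess_image` and `predict` helpers come from `utils/inference_utils.py`, which is not part of this upload. Here is a minimal sketch of what they are assumed to do (resize to the model input size, normalize, then a no-grad forward pass that returns class probabilities); the normalization statistics below are the precomputed channel stats from `train.py` and may differ from the actual implementation:
+
+ ```python
+ from PIL import Image
+ import torch
+ from torchvision import transforms
+
+
+ def preprocess_image(image_path, size):
+     # Load one image and turn it into a normalized (1, C, H, W) batch.
+     tfm = transforms.Compose([
+         transforms.Resize(size),
+         transforms.ToTensor(),
+         # Assumed stats: the precomputed per-channel values from train.py.
+         transforms.Normalize(mean=[0.6062, 0.5889, 0.5550], std=[0.3284, 0.3115, 0.3266]),
+     ])
+     return tfm(Image.open(image_path).convert("RGB")).unsqueeze(0)
+
+
+ def predict(model, image):
+     # No-grad forward pass; returns softmax probabilities over the classes.
+     model.eval()
+     with torch.no_grad():
+         return torch.softmax(model(image), dim=1)
+ ```
+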
+ ## Generating Data Samples
+
+ To generate data samples, use the `get_samples.py` script. Here are the parameters you can specify:
+
+ ```python
+ def main():
+     parser = argparse.ArgumentParser(description="Generate Data Samples")
+     parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet")
+     parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth")
+     parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/")
+     parser.add_argument("--num_classes", type=int, help="Number of classes", default=150)
+     parser.add_argument("--label", type=str, help="Label to filter the images", default='Dragonair')
+     parser.add_argument("--num_correct", type=int, help="Number of correctly classified images", default=5)
+     parser.add_argument("--num_incorrect", type=int, help="Number of incorrectly classified images", default=5)
+     args = parser.parse_args()
+ ```
+
+ Example:
+
+ ```shell
+ python get_samples.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image_directory --num_classes 10 --label Pikachu --num_correct 5 --num_incorrect 5
+ ```
+
+ ## Interpretability
+
+ To interpret the model's predictions using LIME, use the `inference.py` script with the `--lime_interpretability` flag.
+
+ Example:
+
+ ```shell
+ python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10 --lime_interpretability
+ ```
+
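+ The LIME pass itself lives in `utils/interpretability.py` (`lime_interpret_image_inference`), which is not included in this upload. As a rough idea of what such a step does, here is a minimal sketch using the standard `lime` image API; the helper name, segmentation defaults, and sample count are illustrative assumptions, and in practice the same normalization used at training time should be applied inside `classifier_fn`:
+
+ ```python
+ import torch
+ from lime import lime_image
+ from skimage.segmentation import mark_boundaries
+
+
+ def explain_with_lime(model, image_np, device, num_samples=1000):
+     # image_np: a single H x W x C image in pixel space; LIME perturbs its superpixels.
+     def classifier_fn(batch):
+         # LIME hands back a numpy batch of perturbed H x W x C images.
+         x = torch.tensor(batch, dtype=torch.float32).permute(0, 3, 1, 2).to(device)
+         with torch.no_grad():
+             return torch.softmax(model(x), dim=1).cpu().numpy()
+
+     explainer = lime_image.LimeImageExplainer()
+     explanation = explainer.explain_instance(
+         image_np, classifier_fn, top_labels=1, hide_color=0, num_samples=num_samples
+     )
+     # Keep the superpixels that most support the top predicted class.
+     img, mask = explanation.get_image_and_mask(
+         explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False
+     )
+     return mark_boundaries(img, mask)
+ ```
+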
+ ## Contributing
+
+ Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.
environment.yaml ADDED
@@ -0,0 +1,313 @@
+ name: cloudspace
+ channels:
+   - conda-forge
+   - defaults
+ dependencies:
+   - _libgcc_mutex=0.1=main
+   - _openmp_mutex=5.1=1_gnu
+   - aiohappyeyeballs=2.4.3=py310h06a4308_0
+   - alembic=1.13.3=py310h06a4308_0
+   - aniso8601=9.0.1=pyhd3eb1b0_0
+   - arrow-cpp=16.1.0=hc1eb8f0_0
+   - attrs=24.2.0=py310h06a4308_0
+   - bcrypt=3.2.0=py310h5eee18b_1
+   - blas=1.0=openblas
+   - blinker=1.6.2=py310h06a4308_0
+   - boost-cpp=1.82.0=hdb19cb5_2
+   - bottleneck=1.4.2=py310ha9d4c09_0
+   - brotli=1.0.9=h5eee18b_8
+   - brotli-bin=1.0.9=h5eee18b_8
+   - brotli-python=1.0.9=py310h6a678d5_8
+   - bzip2=1.0.8=h5eee18b_6
+   - c-ares=1.19.1=h5eee18b_0
+   - ca-certificates=2024.11.26=h06a4308_0
+   - certifi=2024.8.30=py310h06a4308_0
+   - cffi=1.17.1=py310h1fdaa30_0
+   - click=8.1.7=py310h06a4308_0
+   - cloudpickle=3.0.0=py310h06a4308_0
+   - contourpy=1.3.1=py310hdb19cb5_0
+   - cryptography=43.0.3=py310h7825ff9_1
+   - databricks-sdk=0.33.0=py310h06a4308_0
+   - deprecated=1.2.13=py310h06a4308_0
+   - docker-py=7.1.0=py310h06a4308_0
+   - entrypoints=0.4=py310h06a4308_0
+   - flask=3.0.3=py310h06a4308_0
+   - freetype=2.12.1=h4a9f257_0
+   - frozenlist=1.5.0=py310h5eee18b_0
+   - gflags=2.2.2=h6a678d5_1
+   - gitdb=4.0.7=pyhd3eb1b0_0
+   - gitpython=3.1.43=py310h06a4308_0
+   - glog=0.5.0=h6a678d5_1
+   - graphene=3.3=py310h06a4308_0
+   - graphql-core=3.2.3=py310h06a4308_1
+   - graphql-relay=3.2.0=py310h06a4308_0
+   - greenlet=3.0.1=py310h6a678d5_0
+   - gunicorn=22.0.0=py310h06a4308_0
+   - icu=73.1=h6a678d5_0
+   - importlib-metadata=8.5.0=py310h06a4308_0
+   - itsdangerous=2.2.0=py310h06a4308_0
+   - jinja2=3.1.4=py310h06a4308_1
+   - joblib=1.4.2=py310h06a4308_0
+   - jpeg=9e=h5eee18b_3
+   - krb5=1.20.1=h143b758_1
+   - lcms2=2.12=h3be6417_0
+   - ld_impl_linux-64=2.40=h12ee557_0
+   - lerc=3.0=h295c915_0
+   - libabseil=20240116.2=cxx17_h6a678d5_0
+   - libboost=1.82.0=h109eef0_2
+   - libbrotlicommon=1.0.9=h5eee18b_8
+   - libbrotlidec=1.0.9=h5eee18b_8
+   - libbrotlienc=1.0.9=h5eee18b_8
+   - libcurl=8.9.1=h251f7ec_0
+   - libdeflate=1.17=h5eee18b_1
+   - libedit=3.1.20230828=h5eee18b_0
+   - libev=4.33=h7f8727e_1
+   - libevent=2.1.12=hdbd6064_1
+   - libffi=3.4.4=h6a678d5_1
+   - libgcc-ng=11.2.0=h1234567_1
+   - libgfortran-ng=11.2.0=h00389a5_1
+   - libgfortran5=11.2.0=h1234567_1
+   - libgomp=11.2.0=h1234567_1
+   - libgrpc=1.62.2=h2d74bed_0
+   - libnghttp2=1.57.0=h2d74bed_0
+   - libopenblas=0.3.21=h043d6bf_0
+   - libpng=1.6.39=h5eee18b_0
+   - libprotobuf=4.25.3=he621ea3_0
+   - libsodium=1.0.18=h7b6447c_0
+   - libssh2=1.11.1=h251f7ec_0
+   - libstdcxx-ng=11.2.0=h1234567_1
+   - libthrift=0.15.0=h1795dd8_2
+   - libtiff=4.5.1=h6a678d5_0
+   - libuuid=1.41.5=h5eee18b_0
+   - libwebp-base=1.3.2=h5eee18b_1
+   - lz4-c=1.9.4=h6a678d5_1
+   - mako=1.2.3=py310h06a4308_0
+   - matplotlib-base=3.9.2=py310hbfdbfaf_1
+   - mlflow=2.18.0=hff52083_0
+   - mlflow-skinny=2.18.0=py310hff52083_0
+   - mlflow-ui=2.18.0=py310hff52083_0
+   - multidict=6.1.0=py310h5eee18b_0
+   - ncurses=6.4=h6a678d5_0
+   - numexpr=2.10.1=py310hd28fd6d_0
+   - numpy=1.26.4=py310heeff2f4_0
+   - numpy-base=1.26.4=py310h8a23956_0
+   - openjpeg=2.5.2=he7f1fd0_0
+   - openssl=3.0.15=h5eee18b_0
+   - opentelemetry-api=1.16.0=pyhd8ed1ab_0
+   - opentelemetry-sdk=1.16.0=pyhd8ed1ab_0
+   - opentelemetry-semantic-conventions=0.37b0=pyhd8ed1ab_0
+   - orc=2.0.1=h2d29ad5_0
+   - paramiko=3.5.0=py310h06a4308_0
+   - pillow=11.0.0=py310hfdbf927_0
+   - prometheus_client=0.21.0=py310h06a4308_0
+   - prometheus_flask_exporter=0.22.4=py310h06a4308_0
+   - propcache=0.2.0=py310h5eee18b_0
+   - pyarrow=16.1.0=py310h1128e8f_0
+   - pynacl=1.5.0=py310h5eee18b_0
+   - pyopenssl=24.2.1=py310h06a4308_0
+   - pyparsing=3.2.0=py310h06a4308_0
+   - pysocks=1.7.1=py310h06a4308_0
+   - python=3.10.15=he870216_1
+   - python-dateutil=2.9.0post0=py310h06a4308_2
+   - python-tzdata=2023.3=pyhd3eb1b0_0
+   - python_abi=3.10=2_cp310
+   - pyyaml=6.0.2=py310h5eee18b_0
+   - querystring_parser=1.2.4=py310h06a4308_0
+   - re2=2022.04.01=h295c915_0
+   - readline=8.2=h5eee18b_0
+   - requests=2.32.3=py310h06a4308_1
+   - s2n=1.3.27=hdbd6064_0
+   - setuptools=75.1.0=py310h06a4308_0
+   - six=1.16.0=pyhd3eb1b0_1
+   - smmap=4.0.0=pyhd3eb1b0_0
+   - snappy=1.2.1=h6a678d5_0
+   - sqlalchemy=2.0.34=py310h00e1ef3_0
+   - sqlite=3.45.3=h5eee18b_0
+   - sqlparse=0.4.4=py310h06a4308_0
+   - threadpoolctl=3.5.0=py310h2f386ee_0
+   - tk=8.6.14=h39e8969_0
+   - typing_extensions=4.11.0=py310h06a4308_0
+   - unicodedata2=15.1.0=py310h5eee18b_0
+   - urllib3=2.2.3=py310h06a4308_0
+   - utf8proc=2.6.1=h5eee18b_1
+   - websocket-client=1.8.0=py310h06a4308_0
+   - wheel=0.44.0=py310h06a4308_0
+   - wrapt=1.14.1=py310h5eee18b_0
+   - xz=5.4.6=h5eee18b_1
+   - yaml=0.2.5=h7b6447c_0
+   - yarl=1.18.0=py310h5eee18b_0
+   - zipp=3.21.0=py310h06a4308_0
+   - zlib=1.2.13=h5eee18b_1
+   - zstd=1.5.6=hc292b87_0
+   - pip:
+     - absl-py==2.1.0
+     - aiohttp==3.11.7
+     - aiosignal==1.3.1
+     - annotated-types==0.7.0
+     - anyio==4.6.2.post1
+     - argon2-cffi==23.1.0
+     - argon2-cffi-bindings==21.2.0
+     - arrow==1.3.0
+     - asttokens==2.4.1
+     - async-lru==2.0.4
+     - async-timeout==5.0.1
+     - babel==2.16.0
+     - backoff==2.2.1
+     - beautifulsoup4==4.12.3
+     - bleach==6.2.0
+     - boto3==1.35.70
+     - botocore==1.35.70
+     - cachetools==5.5.0
+     - charset-normalizer==3.4.0
+     - comm==0.2.2
+     - cycler==0.12.1
+     - debugpy==1.8.9
+     - decorator==5.1.1
+     - defusedxml==0.7.1
+     - exceptiongroup==1.2.2
+     - executing==2.1.0
+     - fastapi==0.115.5
+     - fastjsonschema==2.20.0
+     - filelock==3.16.1
+     - fire==0.7.0
+     - fonttools==4.55.0
+     - fqdn==1.5.1
+     - fsspec==2024.10.0
+     - git-filter-repo==2.47.0
+     - google-auth==2.36.0
+     - google-auth-oauthlib==1.2.1
+     - grpcio==1.68.0
+     - h11==0.14.0
+     - httpcore==1.0.7
+     - httptools==0.6.4
+     - httpx==0.27.2
+     - idna==3.10
+     - imageio==2.36.1
+     - ipykernel==6.26.0
+     - ipython==8.17.2
+     - ipywidgets==8.1.1
+     - isoduration==20.11.0
+     - jedi==0.19.2
+     - jmespath==1.0.1
+     - json5==0.10.0
+     - jsonpointer==3.0.0
+     - jsonschema==4.23.0
+     - jsonschema-specifications==2024.10.1
+     - jupyter-client==8.6.3
+     - jupyter-core==5.7.2
+     - jupyter-events==0.10.0
+     - jupyter-lsp==2.2.5
+     - jupyter-server==2.14.2
+     - jupyter-server-terminals==0.5.3
+     - jupyterlab==4.2.0
+     - jupyterlab-pygments==0.3.0
+     - jupyterlab-server==2.27.3
+     - jupyterlab-widgets==3.0.13
+     - kiwisolver==1.4.7
+     - lazy-loader==0.4
+     - lightning==2.4.0
+     - lightning-cloud==0.5.70
+     - lightning-sdk==0.1.30
+     - lightning-utilities==0.11.9
+     - lime==0.2.0.1
+     - litdata==0.2.32
+     - litserve==0.2.5
+     - markdown==3.7
+     - markdown-it-py==3.0.0
+     - markupsafe==3.0.2
+     - matplotlib==3.8.2
+     - matplotlib-inline==0.1.7
+     - mdurl==0.1.2
+     - mistune==3.0.2
+     - mpmath==1.3.0
+     - nbclient==0.10.0
+     - nbconvert==7.16.4
+     - nbformat==5.10.4
+     - nest-asyncio==1.6.0
+     - networkx==3.4.2
+     - notebook-shim==0.2.4
+     - nvidia-cublas-cu12==12.1.3.1
+     - nvidia-cuda-cupti-cu12==12.1.105
+     - nvidia-cuda-nvrtc-cu12==12.1.105
+     - nvidia-cuda-runtime-cu12==12.1.105
+     - nvidia-cudnn-cu12==8.9.2.26
+     - nvidia-cufft-cu12==11.0.2.54
+     - nvidia-curand-cu12==10.3.2.106
+     - nvidia-cusolver-cu12==11.4.5.107
+     - nvidia-cusparse-cu12==12.1.0.106
+     - nvidia-nccl-cu12==2.19.3
+     - nvidia-nvjitlink-cu12==12.6.85
+     - nvidia-nvtx-cu12==12.1.105
+     - oauthlib==3.2.2
+     - overrides==7.7.0
+     - packaging==24.2
+     - pandas==2.1.4
+     - pandocfilters==1.5.1
+     - parso==0.8.4
+     - pexpect==4.9.0
+     - pip==24.3.1
+     - platformdirs==4.3.6
+     - prompt-toolkit==3.0.48
+     - protobuf==4.23.4
+     - psutil==6.1.0
+     - ptyprocess==0.7.0
+     - pure-eval==0.2.3
+     - pyasn1==0.6.1
+     - pyasn1-modules==0.4.1
+     - pycparser==2.22
+     - pydantic==2.10.2
+     - pydantic-core==2.27.1
+     - pygments==2.18.0
+     - pyjwt==2.10.0
+     - python-dotenv==1.0.1
+     - python-json-logger==2.0.7
+     - python-multipart==0.0.17
+     - pytorch-lightning==2.4.0
+     - pytz==2024.2
+     - pyzmq==26.2.0
+     - referencing==0.35.1
+     - requests-oauthlib==2.0.0
+     - rfc3339-validator==0.1.4
+     - rfc3986-validator==0.1.1
+     - rich==13.9.4
+     - rpds-py==0.21.0
+     - rsa==4.9
+     - s3transfer==0.10.4
+     - scikit-image==0.24.0
+     - scikit-learn==1.3.2
+     - scipy==1.11.4
+     - send2trash==1.8.3
+     - simple-term-menu==1.6.5
+     - sniffio==1.3.1
+     - soupsieve==2.6
+     - stack-data==0.6.3
+     - starlette==0.41.3
+     - sympy==1.13.3
+     - tensorboard==2.15.1
+     - tensorboard-data-server==0.7.2
+     - termcolor==2.5.0
+     - terminado==0.18.1
+     - tifffile==2024.9.20
+     - tinycss2==1.4.0
+     - tomli==2.1.0
+     - torch==2.2.1+cu121
+     - torchmetrics==1.3.1
+     - torchvision==0.17.1+cu121
+     - tornado==6.4.2
+     - tqdm==4.67.1
+     - traitlets==5.14.3
+     - triton==2.2.0
+     - types-python-dateutil==2.9.0.20241003
+     - typing-extensions==4.12.2
+     - tzdata==2024.2
+     - uri-template==1.3.0
+     - uvicorn==0.32.1
+     - uvloop==0.21.0
+     - watchfiles==1.0.0
+     - wcwidth==0.2.13
+     - webcolors==24.11.1
+     - webencodings==0.5.1
+     - websockets==14.1
+     - werkzeug==3.1.3
+     - widgetsnbextension==4.0.13
+ prefix: /home/zeus/miniconda3/envs/cloudspace
get_data.ps1 ADDED
@@ -0,0 +1,11 @@
+ # Download the file using Invoke-WebRequest
+ $destination = "./pokemonclassification.zip"
+ $url = "https://www.kaggle.com/api/v1/datasets/download/lantian773030/pokemonclassification"
+ Invoke-WebRequest -Uri $url -OutFile $destination
+
+ # Extract the zip file to the specified folder
+ $extractPath = "./pokemonclassification"
+ Expand-Archive -Path $destination -DestinationPath $extractPath
+
+ # Remove the downloaded zip file (if not needed anymore)
+ Remove-Item $destination
get_data.sh ADDED
@@ -0,0 +1,7 @@
+ #!/bin/bash
+ curl -L -o ./pokemonclassification.zip https://www.kaggle.com/api/v1/datasets/download/lantian773030/pokemonclassification
+
+ unzip -d ./pokemonclassification pokemonclassification.zip
+
+ # Remove the downloaded zip file (if you don't need it anymore)
+ rm ./pokemonclassification.zip
get_samples.py ADDED
@@ -0,0 +1,55 @@
+ from utils.inference_utils import find_images_from_path
+ import torch
+ import argparse
+ from utils.train_utils import initialize_model
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Image Inference")
+     parser.add_argument(
+         "--model_name",
+         type=str,
+         help="Model name (resnet, alexnet, vgg, squeezenet, densenet)",
+         default="resnet",
+     )
+     parser.add_argument(
+         "--model_weights",
+         type=str,
+         help="Path to the model weights",
+         default="./trained_models/pokemon_resnet.pth",
+     )
+     parser.add_argument(
+         "--image_path",
+         type=str,
+         help="Path to the image",
+         default="./pokemonclassification/PokemonData/",
+     )
+     parser.add_argument(
+         "--num_classes", type=int, help="Number of classes", default=150
+     )
+     parser.add_argument(
+         "--label", type=str, help="Label to filter the images", default='Dragonair'  # Krabby, Clefairy
+     )
+     parser.add_argument(
+         "--num_correct", type=int, help="Number of correctly classified images", default=5
+     )
+     parser.add_argument(
+         "--num_incorrect", type=int, help="Number of incorrectly classified images", default=5
+     )
+
+     args = parser.parse_args()
+
+     assert args.model_name == "resnet", "Only the ResNet model is supported for now"
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Initialize the model
+     model = initialize_model(args.model_name, args.num_classes)
+     model = model.to(device)
+
+     # Load the model weights
+     model.load_state_dict(torch.load(args.model_weights, map_location=device))
+     find_images_from_path(args.image_path, model, device, args.num_correct, args.num_incorrect, args.label)
+
+
+ if __name__ == "__main__":
+     main()
indices_60_32.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:32b56617770be9430d034b41ef9235bb938cd1db40a78bb15a7d579229c79511
+ size 20240
inference.py ADDED
@@ -0,0 +1,70 @@
+ import torch
+ import argparse
+ from utils.inference_utils import preprocess_image, predict
+ from utils.train_utils import initialize_model
+ from utils.interpretability import lime_interpret_image_inference
+ from utils.data import CLASS_NAMES
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Image Inference")
+     parser.add_argument(
+         "--model_name",
+         type=str,
+         help="Model name (resnet, alexnet, vgg, squeezenet, densenet)",
+         default="resnet",
+     )
+     parser.add_argument(
+         "--model_weights",
+         type=str,
+         help="Path to the model weights",
+         default="./trained_models/pokemon_resnet.pth",
+     )
+     parser.add_argument(
+         "--image_path",
+         type=str,
+         help="Path to the image",
+         default="./pokemonclassification/PokemonData/Chansey/57ccf27cba024fac9531baa9f619ec62.jpg",
+     )
+     parser.add_argument(
+         "--num_classes", type=int, help="Number of classes", default=150
+     )
+     parser.add_argument(
+         "--lime_interpretability",
+         action="store_true",
+         help="Whether to run interpretability or not",
+     )
+     parser.add_argument(
+         "--classify",
+         action="store_true",
+         help="Whether to classify the image when saving the lime filter")
+
+     args = parser.parse_args()
+
+     if args.lime_interpretability:
+         assert (
+             args.model_name == "resnet"
+         ), "Interpretability is only supported for the ResNet model for now"
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     # Initialize the model
+     model = initialize_model(args.model_name, args.num_classes)
+     model = model.to(device)
+
+     # Load the model weights
+     model.load_state_dict(torch.load(args.model_weights, map_location=device))
+
+     # Preprocess the image
+     image = preprocess_image(args.image_path, (224, 224)).to(device)
+
+     # Perform inference
+     preds = torch.max(predict(model, image), 1)[1]
+     print(f"Predicted class: {CLASS_NAMES[preds.item()]}")
+
+     if args.lime_interpretability:
+         lime_interpret_image_inference(args, model, image, device)
+
+
+ if __name__ == "__main__":
+     main()
train.py ADDED
@@ -0,0 +1,174 @@
+ import torch.nn as nn
+ from torchvision import transforms
+ from utils.data import PokemonDataModule
+ from utils.train import initialize_model, train_and_evaluate
+ import torch
+ import torch.optim as optim
+ import mlflow
+ import argparse
+ import random
+
+ # The shape of the images that the models expect
+ IMG_SHAPE = (224, 224)
+
+
+ def parser_args():
+     parser = argparse.ArgumentParser(description="Pokemon Classification")
+     parser.add_argument(
+         "--data_dir",
+         type=str,
+         default="./pokemonclassification/PokemonData",
+         help="Path to the data directory",
+     )
+     parser.add_argument(
+         "--indices_file",
+         type=str,
+         default="indices_60_32.pkl",
+         help="Path to the indices file",
+     )
+     parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
+     parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
+     parser.add_argument(
+         "--train_batch_size", type=int, default=128, help="Train batch size"
+     )
+     parser.add_argument(
+         "--test_batch_size", type=int, default=512, help="Test batch size"
+     )
+     parser.add_argument(
+         "--model",
+         type=str,
+         choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"],
+         default="resnet",
+         help="Model to be used",
+     )
+     parser.add_argument(
+         "--feature_extract",
+         type=bool,
+         default=True,
+         help="Whether to freeze the backbone or not",
+     )
+     parser.add_argument(
+         "--use_pretrained",
+         type=bool,
+         default=True,
+         help="Whether to use a pretrained model or not",
+     )
+     parser.add_argument(
+         "--experiment_id",
+         type=int,
+         default=0,
+         help="Experiment ID to log the results",
+     )
+     return parser.parse_args()
+
+
+ if __name__ == "__main__":
+     args = parser_args()
+
+     pokemon_dataset = PokemonDataModule(args.data_dir)
+     NUM_CLASSES = len(pokemon_dataset.class_names)
+
+     # Print the number of classes
+     print(f"Number of classes: {NUM_CLASSES}")
+
+     # You can only use the precomputed means and vars if using the same indices file ('indices_60_32.pkl')
+     if "indices_60_32.pkl" in args.indices_file:
+         chanel_means = torch.tensor([0.6062, 0.5889, 0.5550])
+         chanel_vars = torch.tensor([0.3284, 0.3115, 0.3266])
+         stats = {"mean": chanel_means, "std": chanel_vars}
+         _ = pokemon_dataset.prepare_data(
+             indices_file=args.indices_file, get_stats=False
+         )
+     else:
+         stats = pokemon_dataset.prepare_data(
+             indices_file=args.indices_file, get_stats=True
+         )
+
+     print(f"Train dataset size: {len(pokemon_dataset.train_dataset)}")
+     print(f"Test dataset size: {len(pokemon_dataset.test_dataset)}")
+
+     # Transformations of the data for testing
+     test_transform = transforms.Compose(
+         [
+             transforms.Resize(IMG_SHAPE),
+             transforms.ToTensor(),  # Convert PIL images to tensors
+             transforms.Normalize(**stats),  # Normalize images using mean and std
+         ]
+     )
+
+     # Data augmentations for training
+     train_transform = transforms.Compose(
+         [
+             transforms.Resize(IMG_SHAPE),
+             transforms.RandomRotation(10),
+             transforms.RandomHorizontalFlip(),
+             transforms.RandomCrop(IMG_SHAPE, padding=4),
+             transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
+             transforms.ColorJitter(
+                 brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2
+             ),
+             transforms.RandomGrayscale(p=0.2),
+             transforms.ToTensor(),
+             transforms.Normalize(**stats),
+         ]
+     )
+
+     # Get dataloaders
+     trainloader, testloader = pokemon_dataset.get_dataloaders(
+         train_transform=train_transform,
+         test_transform=test_transform,
+         train_batch_size=args.train_batch_size,
+         test_batch_size=args.test_batch_size,
+     )
+
+     pokemon_dataset.plot_examples(testloader, stats=stats)
+
+     pokemon_dataset.plot_examples(trainloader, stats=stats)
+
+     # Try fine-tuning a ResNet, for example
+     model = initialize_model(
+         args.model,
+         NUM_CLASSES,
+         feature_extract=args.feature_extract,
+         use_pretrained=args.use_pretrained,
+     )
+
+     # Print the model we just instantiated
+     print(model)
+
+     # Model, criterion, optimizer
+     criterion = nn.CrossEntropyLoss()
+     optimizer = optim.Adam(model.parameters(), lr=args.lr)
+
+     # Device configuration
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     with mlflow.start_run(
+         experiment_id=args.experiment_id,
+         run_name=f"{args.model}_{'finetuning' if not args.feature_extract else 'feature_extracting'}"
+         f"_{'pretrained' if args.use_pretrained else 'not_pretrained'}"
+         f"_{args.indices_file}_{random.randint(0, 1000)}",
+     ) as run:
+         mlflow.log_param("epochs", args.epochs)
+         mlflow.log_param("lr", args.lr)
+         mlflow.log_param("train_batch_size", args.train_batch_size)
+         mlflow.log_param("test_batch_size", args.test_batch_size)
+         mlflow.log_param("model", args.model)
+         mlflow.log_param("feature_extract", args.feature_extract)
+         mlflow.log_param("use_pretrained", args.use_pretrained)
+
+         # Train and evaluate
+         history = train_and_evaluate(
+             model=model,
+             trainloader=trainloader,
+             testloader=testloader,
+             criterion=criterion,
+             optimizer=optimizer,
+             device=device,
+             epochs=args.epochs,
+             use_mlflow=True,
+         )
+         # Save the model
+         torch.save(model.state_dict(), f"pokemon_{args.model}.pth")
+         mlflow.log_artifact(f"pokemon_{args.model}.pth")
+         mlflow.end_run()