Upload 9 files
Browse files

- .gitignore +7 -0
- README.md +135 -3
- environment.yaml +313 -0
- get_data.ps1 +11 -0
- get_data.sh +7 -0
- get_samples.py +55 -0
- indices_60_32.pkl +3 -0
- inference.py +70 -0
- train.py +174 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
.vscode
pokemonclassification
*.zip
trained_models/pokemon_vgg.pth
*.cpython-310.pyc
__pycache__
utils/__pycache__/data.cpython-310.pyc
README.md
CHANGED
@@ -1,3 +1,135 @@
# PokemonClassification

This repository explores the training of different models on a vision classification task, with a special focus on reproducibility and an attempt at local interpretability of the decisions made by a ResNet model using LIME.

## Table of Contents

- [PokemonClassification](#pokemonclassification)
  - [Table of Contents](#table-of-contents)
  - [Installation](#installation)
  - [Dataset](#dataset)
  - [Training](#training)
  - [Inference](#inference)
  - [Generating Data Samples](#generating-data-samples)
  - [Interpretability](#interpretability)
  - [Contributing](#contributing)

## Installation

1. Clone the repository:

   ```sh
   git clone https://github.com/yourusername/PokemonClassification.git
   cd PokemonClassification
   ```

2. Create a conda environment and activate it:

   ```sh
   conda env create -f environment.yaml
   conda activate pokemonclassification
   ```

## Dataset

To get the data, use the appropriate script for your operating system:

- On Linux-based systems:

  ```shell
  ./get_data.sh
  ```

- On Windows:

  ```shell
  ./get_data.ps1
  ```
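If the direct download URL requires Kaggle authentication on your machine, the official Kaggle CLI is an alternative (this assumes the `kaggle` package is installed and an API token is configured; it is not part of this repository's scripts):

```shell
kaggle datasets download -d lantian773030/pokemonclassification -p ./pokemonclassification --unzip
```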
## Training

To train a model, use the `train.py` script. Here are the parameters you can specify:

```python
def parser_args():
    parser = argparse.ArgumentParser(description="Pokemon Classification")
    parser.add_argument("--data_dir", type=str, default="./pokemonclassification/PokemonData", help="Path to the data directory")
    parser.add_argument("--indices_file", type=str, default="indices_60_32.pkl", help="Path to the indices file")
    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--train_batch_size", type=int, default=128, help="Train batch size")
    parser.add_argument("--test_batch_size", type=int, default=512, help="Test batch size")
    parser.add_argument("--model", type=str, choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"], default="resnet", help="Model to be used")
    parser.add_argument("--feature_extract", type=bool, default=True, help="Whether to freeze the backbone or not")
    parser.add_argument("--use_pretrained", type=bool, default=True, help="Whether to use a pretrained model or not")
    parser.add_argument("--experiment_id", type=int, default=0, help="Experiment ID to log the results")
    return parser.parse_args()
```

Example:

```shell
python train.py --model resnet --data_dir data/PokemonData --epochs 10 --train_batch_size 32 --test_batch_size 32
```
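A note on the two boolean flags: argparse's `type=bool` applies Python's `bool()` to the raw string, and any non-empty string (including `"False"`) is truthy. A minimal illustration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--feature_extract", type=bool, default=True)

# bool("False") is True, so this does NOT unfreeze the backbone:
print(parser.parse_args(["--feature_extract", "False"]).feature_extract)  # True
# Only the empty string converts to False:
print(parser.parse_args(["--feature_extract", ""]).feature_extract)       # False
```

So to fine-tune the whole backbone rather than feature-extract, pass `--feature_extract ""` (or change the default in `train.py`).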
## Inference

To perform inference on a single image, use the `inference.py` script. Here are the parameters you can specify:

```python
def main():
    parser = argparse.ArgumentParser(description="Image Inference")
    parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet")
    parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth")
    parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/Chansey/57ccf27cba024fac9531baa9f619ec62.jpg")
    parser.add_argument("--num_classes", type=int, help="Number of classes", default=150)
    parser.add_argument("--lime_interpretability", action="store_true", help="Whether to run interpretability or not")
    parser.add_argument("--classify", action="store_true", help="Whether to classify the image when saving the LIME filter")
    args = parser.parse_args()

    if args.lime_interpretability:
        assert args.model_name == "resnet", "Interpretability is only supported for the ResNet model for now"
```

Example:

```shell
python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10
```

## Generating Data Samples

To generate data samples, use the `get_samples.py` script. Here are the parameters you can specify:

```python
def main():
    parser = argparse.ArgumentParser(description="Generate Data Samples")
    parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet")
    parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth")
    parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/")
    parser.add_argument("--num_classes", type=int, help="Number of classes", default=150)
    parser.add_argument("--label", type=str, help="Label to filter the images", default='Dragonair')
    parser.add_argument("--num_correct", type=int, help="Number of correctly classified images", default=5)
    parser.add_argument("--num_incorrect", type=int, help="Number of incorrectly classified images", default=5)
    args = parser.parse_args()
```

Example:

```shell
python get_samples.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image_directory --num_classes 10 --label Pikachu --num_correct 5 --num_incorrect 5
```

## Interpretability

To interpret the model's predictions using LIME, use the `inference.py` script with the `--lime_interpretability` flag.

Example:

```shell
python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10 --lime_interpretability
```
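Under the hood, LIME for images perturbs superpixels of the input and fits a local linear model to the classifier's outputs. The heavy lifting here is done by `lime_interpret_image_inference` in `utils/interpretability.py` (not shown in this upload); as a rough, purely illustrative sketch of the general pattern with the `lime` package (the function name, preprocessing, and parameter choices below are assumptions, not the repository's exact code):

```python
import torch
import torch.nn.functional as F
from lime import lime_image

def explain(model, image_np, device, num_samples=1000):
    """image_np: an HxWx3 float array in [0, 1] (hypothetical helper)."""
    def classifier_fn(batch):
        # LIME passes a batch of perturbed HxWx3 numpy images.
        tensors = torch.stack(
            [torch.from_numpy(img).permute(2, 0, 1).float() for img in batch]
        ).to(device)
        with torch.no_grad():
            return F.softmax(model(tensors), dim=1).cpu().numpy()

    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        image_np, classifier_fn, top_labels=1, num_samples=num_samples
    )
    # Image and mask highlighting the superpixels that most support the top prediction.
    return explanation.get_image_and_mask(
        explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False
    )
```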
## Contributing

Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.
environment.yaml
ADDED
@@ -0,0 +1,313 @@
name: cloudspace
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - aiohappyeyeballs=2.4.3=py310h06a4308_0
  - alembic=1.13.3=py310h06a4308_0
  - aniso8601=9.0.1=pyhd3eb1b0_0
  - arrow-cpp=16.1.0=hc1eb8f0_0
  - attrs=24.2.0=py310h06a4308_0
  - bcrypt=3.2.0=py310h5eee18b_1
  - blas=1.0=openblas
  - blinker=1.6.2=py310h06a4308_0
  - boost-cpp=1.82.0=hdb19cb5_2
  - bottleneck=1.4.2=py310ha9d4c09_0
  - brotli=1.0.9=h5eee18b_8
  - brotli-bin=1.0.9=h5eee18b_8
  - brotli-python=1.0.9=py310h6a678d5_8
  - bzip2=1.0.8=h5eee18b_6
  - c-ares=1.19.1=h5eee18b_0
  - ca-certificates=2024.11.26=h06a4308_0
  - certifi=2024.8.30=py310h06a4308_0
  - cffi=1.17.1=py310h1fdaa30_0
  - click=8.1.7=py310h06a4308_0
  - cloudpickle=3.0.0=py310h06a4308_0
  - contourpy=1.3.1=py310hdb19cb5_0
  - cryptography=43.0.3=py310h7825ff9_1
  - databricks-sdk=0.33.0=py310h06a4308_0
  - deprecated=1.2.13=py310h06a4308_0
  - docker-py=7.1.0=py310h06a4308_0
  - entrypoints=0.4=py310h06a4308_0
  - flask=3.0.3=py310h06a4308_0
  - freetype=2.12.1=h4a9f257_0
  - frozenlist=1.5.0=py310h5eee18b_0
  - gflags=2.2.2=h6a678d5_1
  - gitdb=4.0.7=pyhd3eb1b0_0
  - gitpython=3.1.43=py310h06a4308_0
  - glog=0.5.0=h6a678d5_1
  - graphene=3.3=py310h06a4308_0
  - graphql-core=3.2.3=py310h06a4308_1
  - graphql-relay=3.2.0=py310h06a4308_0
  - greenlet=3.0.1=py310h6a678d5_0
  - gunicorn=22.0.0=py310h06a4308_0
  - icu=73.1=h6a678d5_0
  - importlib-metadata=8.5.0=py310h06a4308_0
  - itsdangerous=2.2.0=py310h06a4308_0
  - jinja2=3.1.4=py310h06a4308_1
  - joblib=1.4.2=py310h06a4308_0
  - jpeg=9e=h5eee18b_3
  - krb5=1.20.1=h143b758_1
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.40=h12ee557_0
  - lerc=3.0=h295c915_0
  - libabseil=20240116.2=cxx17_h6a678d5_0
  - libboost=1.82.0=h109eef0_2
  - libbrotlicommon=1.0.9=h5eee18b_8
  - libbrotlidec=1.0.9=h5eee18b_8
  - libbrotlienc=1.0.9=h5eee18b_8
  - libcurl=8.9.1=h251f7ec_0
  - libdeflate=1.17=h5eee18b_1
  - libedit=3.1.20230828=h5eee18b_0
  - libev=4.33=h7f8727e_1
  - libevent=2.1.12=hdbd6064_1
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=11.2.0=h1234567_1
  - libgfortran-ng=11.2.0=h00389a5_1
  - libgfortran5=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libgrpc=1.62.2=h2d74bed_0
  - libnghttp2=1.57.0=h2d74bed_0
  - libopenblas=0.3.21=h043d6bf_0
  - libpng=1.6.39=h5eee18b_0
  - libprotobuf=4.25.3=he621ea3_0
  - libsodium=1.0.18=h7b6447c_0
  - libssh2=1.11.1=h251f7ec_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libthrift=0.15.0=h1795dd8_2
  - libtiff=4.5.1=h6a678d5_0
  - libuuid=1.41.5=h5eee18b_0
  - libwebp-base=1.3.2=h5eee18b_1
  - lz4-c=1.9.4=h6a678d5_1
  - mako=1.2.3=py310h06a4308_0
  - matplotlib-base=3.9.2=py310hbfdbfaf_1
  - mlflow=2.18.0=hff52083_0
  - mlflow-skinny=2.18.0=py310hff52083_0
  - mlflow-ui=2.18.0=py310hff52083_0
  - multidict=6.1.0=py310h5eee18b_0
  - ncurses=6.4=h6a678d5_0
  - numexpr=2.10.1=py310hd28fd6d_0
  - numpy=1.26.4=py310heeff2f4_0
  - numpy-base=1.26.4=py310h8a23956_0
  - openjpeg=2.5.2=he7f1fd0_0
  - openssl=3.0.15=h5eee18b_0
  - opentelemetry-api=1.16.0=pyhd8ed1ab_0
  - opentelemetry-sdk=1.16.0=pyhd8ed1ab_0
  - opentelemetry-semantic-conventions=0.37b0=pyhd8ed1ab_0
  - orc=2.0.1=h2d29ad5_0
  - paramiko=3.5.0=py310h06a4308_0
  - pillow=11.0.0=py310hfdbf927_0
  - prometheus_client=0.21.0=py310h06a4308_0
  - prometheus_flask_exporter=0.22.4=py310h06a4308_0
  - propcache=0.2.0=py310h5eee18b_0
  - pyarrow=16.1.0=py310h1128e8f_0
  - pynacl=1.5.0=py310h5eee18b_0
  - pyopenssl=24.2.1=py310h06a4308_0
  - pyparsing=3.2.0=py310h06a4308_0
  - pysocks=1.7.1=py310h06a4308_0
  - python=3.10.15=he870216_1
  - python-dateutil=2.9.0post0=py310h06a4308_2
  - python-tzdata=2023.3=pyhd3eb1b0_0
  - python_abi=3.10=2_cp310
  - pyyaml=6.0.2=py310h5eee18b_0
  - querystring_parser=1.2.4=py310h06a4308_0
  - re2=2022.04.01=h295c915_0
  - readline=8.2=h5eee18b_0
  - requests=2.32.3=py310h06a4308_1
  - s2n=1.3.27=hdbd6064_0
  - setuptools=75.1.0=py310h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - smmap=4.0.0=pyhd3eb1b0_0
  - snappy=1.2.1=h6a678d5_0
  - sqlalchemy=2.0.34=py310h00e1ef3_0
  - sqlite=3.45.3=h5eee18b_0
  - sqlparse=0.4.4=py310h06a4308_0
  - threadpoolctl=3.5.0=py310h2f386ee_0
  - tk=8.6.14=h39e8969_0
  - typing_extensions=4.11.0=py310h06a4308_0
  - unicodedata2=15.1.0=py310h5eee18b_0
  - urllib3=2.2.3=py310h06a4308_0
  - utf8proc=2.6.1=h5eee18b_1
  - websocket-client=1.8.0=py310h06a4308_0
  - wheel=0.44.0=py310h06a4308_0
  - wrapt=1.14.1=py310h5eee18b_0
  - xz=5.4.6=h5eee18b_1
  - yaml=0.2.5=h7b6447c_0
  - yarl=1.18.0=py310h5eee18b_0
  - zipp=3.21.0=py310h06a4308_0
  - zlib=1.2.13=h5eee18b_1
  - zstd=1.5.6=hc292b87_0
  - pip:
    - absl-py==2.1.0
    - aiohttp==3.11.7
    - aiosignal==1.3.1
    - annotated-types==0.7.0
    - anyio==4.6.2.post1
    - argon2-cffi==23.1.0
    - argon2-cffi-bindings==21.2.0
    - arrow==1.3.0
    - asttokens==2.4.1
    - async-lru==2.0.4
    - async-timeout==5.0.1
    - babel==2.16.0
    - backoff==2.2.1
    - beautifulsoup4==4.12.3
    - bleach==6.2.0
    - boto3==1.35.70
    - botocore==1.35.70
    - cachetools==5.5.0
    - charset-normalizer==3.4.0
    - comm==0.2.2
    - cycler==0.12.1
    - debugpy==1.8.9
    - decorator==5.1.1
    - defusedxml==0.7.1
    - exceptiongroup==1.2.2
    - executing==2.1.0
    - fastapi==0.115.5
    - fastjsonschema==2.20.0
    - filelock==3.16.1
    - fire==0.7.0
    - fonttools==4.55.0
    - fqdn==1.5.1
    - fsspec==2024.10.0
    - git-filter-repo==2.47.0
    - google-auth==2.36.0
    - google-auth-oauthlib==1.2.1
    - grpcio==1.68.0
    - h11==0.14.0
    - httpcore==1.0.7
    - httptools==0.6.4
    - httpx==0.27.2
    - idna==3.10
    - imageio==2.36.1
    - ipykernel==6.26.0
    - ipython==8.17.2
    - ipywidgets==8.1.1
    - isoduration==20.11.0
    - jedi==0.19.2
    - jmespath==1.0.1
    - json5==0.10.0
    - jsonpointer==3.0.0
    - jsonschema==4.23.0
    - jsonschema-specifications==2024.10.1
    - jupyter-client==8.6.3
    - jupyter-core==5.7.2
    - jupyter-events==0.10.0
    - jupyter-lsp==2.2.5
    - jupyter-server==2.14.2
    - jupyter-server-terminals==0.5.3
    - jupyterlab==4.2.0
    - jupyterlab-pygments==0.3.0
    - jupyterlab-server==2.27.3
    - jupyterlab-widgets==3.0.13
    - kiwisolver==1.4.7
    - lazy-loader==0.4
    - lightning==2.4.0
    - lightning-cloud==0.5.70
    - lightning-sdk==0.1.30
    - lightning-utilities==0.11.9
    - lime==0.2.0.1
    - litdata==0.2.32
    - litserve==0.2.5
    - markdown==3.7
    - markdown-it-py==3.0.0
    - markupsafe==3.0.2
    - matplotlib==3.8.2
    - matplotlib-inline==0.1.7
    - mdurl==0.1.2
    - mistune==3.0.2
    - mpmath==1.3.0
    - nbclient==0.10.0
    - nbconvert==7.16.4
    - nbformat==5.10.4
    - nest-asyncio==1.6.0
    - networkx==3.4.2
    - notebook-shim==0.2.4
    - nvidia-cublas-cu12==12.1.3.1
    - nvidia-cuda-cupti-cu12==12.1.105
    - nvidia-cuda-nvrtc-cu12==12.1.105
    - nvidia-cuda-runtime-cu12==12.1.105
    - nvidia-cudnn-cu12==8.9.2.26
    - nvidia-cufft-cu12==11.0.2.54
    - nvidia-curand-cu12==10.3.2.106
    - nvidia-cusolver-cu12==11.4.5.107
    - nvidia-cusparse-cu12==12.1.0.106
    - nvidia-nccl-cu12==2.19.3
    - nvidia-nvjitlink-cu12==12.6.85
    - nvidia-nvtx-cu12==12.1.105
    - oauthlib==3.2.2
    - overrides==7.7.0
    - packaging==24.2
    - pandas==2.1.4
    - pandocfilters==1.5.1
    - parso==0.8.4
    - pexpect==4.9.0
    - pip==24.3.1
    - platformdirs==4.3.6
    - prompt-toolkit==3.0.48
    - protobuf==4.23.4
    - psutil==6.1.0
    - ptyprocess==0.7.0
    - pure-eval==0.2.3
    - pyasn1==0.6.1
    - pyasn1-modules==0.4.1
    - pycparser==2.22
    - pydantic==2.10.2
    - pydantic-core==2.27.1
    - pygments==2.18.0
    - pyjwt==2.10.0
    - python-dotenv==1.0.1
    - python-json-logger==2.0.7
    - python-multipart==0.0.17
    - pytorch-lightning==2.4.0
    - pytz==2024.2
    - pyzmq==26.2.0
    - referencing==0.35.1
    - requests-oauthlib==2.0.0
    - rfc3339-validator==0.1.4
    - rfc3986-validator==0.1.1
    - rich==13.9.4
    - rpds-py==0.21.0
    - rsa==4.9
    - s3transfer==0.10.4
    - scikit-image==0.24.0
    - scikit-learn==1.3.2
    - scipy==1.11.4
    - send2trash==1.8.3
    - simple-term-menu==1.6.5
    - sniffio==1.3.1
    - soupsieve==2.6
    - stack-data==0.6.3
    - starlette==0.41.3
    - sympy==1.13.3
    - tensorboard==2.15.1
    - tensorboard-data-server==0.7.2
    - termcolor==2.5.0
    - terminado==0.18.1
    - tifffile==2024.9.20
    - tinycss2==1.4.0
    - tomli==2.1.0
    - torch==2.2.1+cu121
    - torchmetrics==1.3.1
    - torchvision==0.17.1+cu121
    - tornado==6.4.2
    - tqdm==4.67.1
    - traitlets==5.14.3
    - triton==2.2.0
    - types-python-dateutil==2.9.0.20241003
    - typing-extensions==4.12.2
    - tzdata==2024.2
    - uri-template==1.3.0
    - uvicorn==0.32.1
    - uvloop==0.21.0
    - watchfiles==1.0.0
    - wcwidth==0.2.13
    - webcolors==24.11.1
    - webencodings==0.5.1
    - websockets==14.1
    - werkzeug==3.1.3
    - widgetsnbextension==4.0.13
prefix: /home/zeus/miniconda3/envs/cloudspace
get_data.ps1
ADDED
@@ -0,0 +1,11 @@
# Download the file using Invoke-WebRequest
$destination = "./pokemonclassification.zip"
$url = "https://www.kaggle.com/api/v1/datasets/download/lantian773030/pokemonclassification"
Invoke-WebRequest -Uri $url -OutFile $destination

# Extract the zip file to the specified folder
$extractPath = "./pokemonclassification"
Expand-Archive -Path $destination -DestinationPath $extractPath

# Remove the downloaded zip file (if not needed anymore)
Remove-Item $destination
get_data.sh
ADDED
@@ -0,0 +1,7 @@
#!/bin/bash
curl -L -o ./pokemonclassification.zip https://www.kaggle.com/api/v1/datasets/download/lantian773030/pokemonclassification

unzip -d ./pokemonclassification pokemonclassification.zip

# Remove the downloaded zip file (if you don't need it anymore)
rm ./pokemonclassification.zip
get_samples.py
ADDED
@@ -0,0 +1,55 @@
from utils.inference_utils import find_images_from_path
import torch
import argparse
from utils.train_utils import initialize_model


def main():
    parser = argparse.ArgumentParser(description="Generate Data Samples")
    parser.add_argument(
        "--model_name",
        type=str,
        help="Model name (resnet, alexnet, vgg, squeezenet, densenet)",
        default="resnet",
    )
    parser.add_argument(
        "--model_weights",
        type=str,
        help="Path to the model weights",
        default="./trained_models/pokemon_resnet.pth",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        help="Path to the image",
        default="./pokemonclassification/PokemonData/",
    )
    parser.add_argument(
        "--num_classes", type=int, help="Number of classes", default=150
    )
    parser.add_argument(
        "--label", type=str, help="Label to filter the images", default='Dragonair'  # Krabby, Clefairy
    )
    parser.add_argument(
        "--num_correct", type=int, help="Number of correctly classified images", default=5
    )
    parser.add_argument(
        "--num_incorrect", type=int, help="Number of incorrectly classified images", default=5
    )

    args = parser.parse_args()

    assert args.model_name == "resnet", "Only the ResNet model is supported for now"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the model
    model = initialize_model(args.model_name, args.num_classes)
    model = model.to(device)

    # Load the model weights
    model.load_state_dict(torch.load(args.model_weights, map_location=device))
    find_images_from_path(args.image_path, model, device, args.num_correct, args.num_incorrect, args.label)


if __name__ == "__main__":
    main()
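`find_images_from_path` lives in `utils/inference_utils.py`, which is not part of this upload. As a purely illustrative sketch of what a helper with this signature might do (the folder layout, preprocessing, and logic below are assumptions, not the repository's implementation):

```python
import os
import torch
from PIL import Image
from torchvision import transforms

def find_images_from_path(image_path, model, device, num_correct, num_incorrect, label):
    """Collect up to num_correct / num_incorrect images of `label` that the
    model classifies correctly / incorrectly (hypothetical sketch)."""
    to_tensor = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    class_names = sorted(os.listdir(image_path))  # assumes a folder-per-class layout
    label_idx = class_names.index(label)
    correct, incorrect = [], []

    model.eval()
    for fname in os.listdir(os.path.join(image_path, label)):
        img = Image.open(os.path.join(image_path, label, fname)).convert("RGB")
        with torch.no_grad():
            pred = model(to_tensor(img).unsqueeze(0).to(device)).argmax(1).item()
        bucket = correct if pred == label_idx else incorrect
        cap = num_correct if bucket is correct else num_incorrect
        if len(bucket) < cap:
            bucket.append(fname)
        if len(correct) >= num_correct and len(incorrect) >= num_incorrect:
            break
    return correct, incorrect
```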
indices_60_32.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:32b56617770be9430d034b41ef9235bb938cd1db40a78bb15a7d579229c79511
size 20240
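This is a Git LFS pointer, not the pickle itself; the actual `indices_60_32.pkl` (the fixed train/test split that `train.py` consumes for reproducibility) is fetched on `git lfs pull`. Assuming it stores plain index collections, it could be inspected with something like:

```python
import pickle

# Hypothetical inspection; the exact structure of the pickle is an assumption.
with open("indices_60_32.pkl", "rb") as f:
    indices = pickle.load(f)
print(type(indices))  # e.g. a dict or tuple of train/test index arrays
```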
inference.py
ADDED
@@ -0,0 +1,70 @@
import torch
import argparse
from utils.inference_utils import preprocess_image, predict
from utils.train_utils import initialize_model
from utils.interpretability import lime_interpret_image_inference
from utils.data import CLASS_NAMES


def main():
    parser = argparse.ArgumentParser(description="Image Inference")
    parser.add_argument(
        "--model_name",
        type=str,
        help="Model name (resnet, alexnet, vgg, squeezenet, densenet)",
        default="resnet",
    )
    parser.add_argument(
        "--model_weights",
        type=str,
        help="Path to the model weights",
        default="./trained_models/pokemon_resnet.pth",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        help="Path to the image",
        default="./pokemonclassification/PokemonData/Chansey/57ccf27cba024fac9531baa9f619ec62.jpg",
    )
    parser.add_argument(
        "--num_classes", type=int, help="Number of classes", default=150
    )
    parser.add_argument(
        "--lime_interpretability",
        action="store_true",
        help="Whether to run interpretability or not",
    )
    parser.add_argument(
        "--classify",
        action="store_true",
        help="Whether to classify the image when saving the LIME filter")

    args = parser.parse_args()

    if args.lime_interpretability:
        assert (
            args.model_name == "resnet"
        ), "Interpretability is only supported for the ResNet model for now"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize the model
    model = initialize_model(args.model_name, args.num_classes)
    model = model.to(device)

    # Load the model weights
    model.load_state_dict(torch.load(args.model_weights, map_location=device))

    # Preprocess the image
    image = preprocess_image(args.image_path, (224, 224)).to(device)

    # Perform inference
    preds = torch.max(predict(model, image), 1)[1]
    print(f"Predicted class: {CLASS_NAMES[preds.item()]}")

    if args.lime_interpretability:
        lime_interpret_image_inference(args, model, image, device)


if __name__ == "__main__":
    main()
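`preprocess_image` and `predict` come from `utils/inference_utils.py`, which is not included in this upload. A minimal sketch of what they plausibly do, given how they are called above (the normalization constants reuse the precomputed channel stats from `train.py`, but whether inference uses exactly these values is an assumption):

```python
import torch
from PIL import Image
from torchvision import transforms

def preprocess_image(image_path, img_shape):
    """Load an image and return a 1x3xHxW tensor (hypothetical sketch)."""
    transform = transforms.Compose([
        transforms.Resize(img_shape),
        transforms.ToTensor(),
        # Assumed to match the training stats for 'indices_60_32.pkl'.
        transforms.Normalize(mean=[0.6062, 0.5889, 0.5550], std=[0.3284, 0.3115, 0.3266]),
    ])
    return transform(Image.open(image_path).convert("RGB")).unsqueeze(0)

def predict(model, image):
    """Forward pass without gradients; returns raw class scores."""
    model.eval()
    with torch.no_grad():
        return model(image)
```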
train.py
ADDED
@@ -0,0 +1,174 @@
import torch.nn as nn
from torchvision import transforms
from utils.data import PokemonDataModule
from utils.train import initialize_model, train_and_evaluate
import torch
import torch.optim as optim
import mlflow
import argparse
import random

# The shape of the images that the models expect
IMG_SHAPE = (224, 224)


def parser_args():
    parser = argparse.ArgumentParser(description="Pokemon Classification")
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./pokemonclassification/PokemonData",
        help="Path to the data directory",
    )
    parser.add_argument(
        "--indices_file",
        type=str,
        default="indices_60_32.pkl",
        help="Path to the indices file",
    )
    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument(
        "--train_batch_size", type=int, default=128, help="Train batch size"
    )
    parser.add_argument(
        "--test_batch_size", type=int, default=512, help="Test batch size"
    )
    parser.add_argument(
        "--model",
        type=str,
        choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"],
        default="resnet",
        help="Model to be used",
    )
    parser.add_argument(
        "--feature_extract",
        type=bool,
        default=True,
        help="Whether to freeze the backbone or not",
    )
    parser.add_argument(
        "--use_pretrained",
        type=bool,
        default=True,
        help="Whether to use a pretrained model or not",
    )
    parser.add_argument(
        "--experiment_id",
        type=int,
        default=0,
        help="Experiment ID to log the results",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parser_args()

    pokemon_dataset = PokemonDataModule(args.data_dir)
    NUM_CLASSES = len(pokemon_dataset.class_names)

    # Get class names
    print(f"Number of classes: {NUM_CLASSES}")

    # The precomputed means and stds are only valid for the same indices file ('indices_60_32.pkl')
    if "indices_60_32.pkl" in args.indices_file:
        chanel_means = torch.tensor([0.6062, 0.5889, 0.5550])
        chanel_vars = torch.tensor([0.3284, 0.3115, 0.3266])
        stats = {"mean": chanel_means, "std": chanel_vars}
        _ = pokemon_dataset.prepare_data(
            indices_file=args.indices_file, get_stats=False
        )
    else:
        stats = pokemon_dataset.prepare_data(
            indices_file=args.indices_file, get_stats=True
        )

    print(f"Train dataset size: {len(pokemon_dataset.train_dataset)}")
    print(f"Test dataset size: {len(pokemon_dataset.test_dataset)}")

    # Transformations of the data for testing
    test_transform = transforms.Compose(
        [
            transforms.Resize(IMG_SHAPE),
            transforms.ToTensor(),  # Convert PIL images to tensors
            transforms.Normalize(**stats),  # Normalize images using mean and std
        ]
    )

    # Data augmentations for training
    train_transform = transforms.Compose(
        [
            transforms.Resize(IMG_SHAPE),
            transforms.RandomRotation(10),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(IMG_SHAPE, padding=4),
            transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(**stats),
        ]
    )

    # Get dataloaders
    trainloader, testloader = pokemon_dataset.get_dataloaders(
        train_transform=train_transform,
        test_transform=test_transform,
        train_batch_size=args.train_batch_size,
        test_batch_size=args.test_batch_size,
    )

    pokemon_dataset.plot_examples(testloader, stats=stats)

    pokemon_dataset.plot_examples(trainloader, stats=stats)

    # Try finetuning a resnet, for example
    model = initialize_model(
        args.model,
        NUM_CLASSES,
        feature_extract=args.feature_extract,
        use_pretrained=args.use_pretrained,
    )

    # Print the model we just instantiated
    print(model)

    # Model, criterion, optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with mlflow.start_run(
        experiment_id=args.experiment_id,
        run_name=f"{args.model}_{'finetuning' if not args.feature_extract else 'feature_extracting'}"
        f"_{'pretrained' if args.use_pretrained else 'not_pretrained'}"
        f"_{args.indices_file}_{random.randint(0, 1000)}",
    ) as run:
        mlflow.log_param("epochs", args.epochs)
        mlflow.log_param("lr", args.lr)
        mlflow.log_param("train_batch_size", args.train_batch_size)
        mlflow.log_param("test_batch_size", args.test_batch_size)
        mlflow.log_param("model", args.model)
        mlflow.log_param("feature_extract", args.feature_extract)
        mlflow.log_param("use_pretrained", args.use_pretrained)

        # Train and evaluate
        history = train_and_evaluate(
            model=model,
            trainloader=trainloader,
            testloader=testloader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epochs=args.epochs,
            use_mlflow=True,
        )
        # Save the model
        torch.save(model.state_dict(), f"pokemon_{args.model}.pth")
        mlflow.log_artifact(f"pokemon_{args.model}.pth")
    mlflow.end_run()
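`initialize_model` (imported from the repo's `utils`, not part of this upload) takes `feature_extract` and `use_pretrained` options and a model list that mirror the classic torchvision finetuning recipe, so a plausible sketch of the resnet branch looks like this (an illustrative assumption, not the repository's exact code):

```python
import torch.nn as nn
from torchvision import models

def initialize_model(model_name, num_classes, feature_extract=True, use_pretrained=True):
    """Sketch of the resnet branch: optionally freeze the backbone and
    replace the final layer with a fresh num_classes-way classifier."""
    if model_name == "resnet":
        model = models.resnet18(weights="IMAGENET1K_V1" if use_pretrained else None)
        if feature_extract:
            for param in model.parameters():
                param.requires_grad = False  # freeze the backbone
        model.fc = nn.Linear(model.fc.in_features, num_classes)  # new head trains
        return model
    raise NotImplementedError(model_name)
```

With `feature_extract=True`, only the freshly created head has `requires_grad=True`, which is what the training script's help text means by "freezing the backbone".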