---
license: mit
base_model:
- microsoft/resnet-18
pipeline_tag: image-classification
library_name: pytorch
---
# PokemonClassification

This repository explores training several models on a vision classification task, with a particular focus on reproducibility, and demonstrates local interpretability of a ResNet model's decisions using LIME.

## Table of Contents

- [PokemonClassification](#pokemonclassification)
  - [Table of Contents](#table-of-contents)
  - [Installation](#installation)
  - [Dataset](#dataset)
  - [Training](#training)
  - [Inference](#inference)
  - [Generating Data Samples](#generating-data-samples)
  - [Interpretability](#interpretability)
  - [Contributing](#contributing)

## Installation

1. Clone the repository:

    ```sh
    git clone https://github.com/yourusername/PokemonClassification.git
    cd PokemonClassification
    ```

2. Create a conda environment and activate it:

    ```sh
    conda env create -f environment.yaml
    conda activate pokemonclassification
    ```

## Dataset

To get the data, use the appropriate script based on your operating system:

- On Linux-based systems:
  
    ```shell
    ./get_data.sh
    ```

- On Windows (PowerShell):
  
    ```shell
    .\get_data.ps1
    ```
    ```

## Training

To train a model, use the `train.py` script. Here are the parameters you can specify:

```python
def parser_args():
    parser = argparse.ArgumentParser(description="Pokemon Classification")
    parser.add_argument("--data_dir", type=str, default="./pokemonclassification/PokemonData", help="Path to the data directory")
    parser.add_argument("--indices_file", type=str, default="indices_60_32.pkl", help="Path to the indices file")
    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--train_batch_size", type=int, default=128, help="train Batch size")
    parser.add_argument("--test_batch_size", type=int, default=512, help="test Batch size")
    parser.add_argument("--model", type=str, choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"], default="resnet", help="Model to be used")
    # Note: argparse's type=bool treats any non-empty string (including "False") as True,
    # so these two flags cannot be disabled from the command line without changing the defaults.
    parser.add_argument("--feature_extract", type=bool, default=True, help="whether to freeze the backbone or not")
    parser.add_argument("--use_pretrained", type=bool, default=True, help="whether to use a pretrained model or not")
    parser.add_argument("--experiment_id", type=int, default=0, help="Experiment ID to log the results")
    return parser.parse_args()
```
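The `--indices_file` argument loads a fixed train/test split so that runs are reproducible. The exact pickle schema is defined by the repository's data code; purely as an illustration, a split file of this kind could be generated as follows (the dict layout here is an assumption, not the repo's actual format):

```python
import pickle
import random

def make_split(num_images, train_fraction=0.8, seed=60, out_path="indices_60_32.pkl"):
    """Shuffle dataset indices with a fixed seed and pickle a train/test split."""
    rng = random.Random(seed)  # fixed seed -> identical split on every run
    indices = list(range(num_images))
    rng.shuffle(indices)
    cut = int(train_fraction * num_images)
    split = {"train": indices[:cut], "test": indices[cut:]}
    with open(out_path, "wb") as f:
        pickle.dump(split, f)
    return split
```

Because the shuffle is seeded, regenerating the file always yields the same partition, which is the point of shipping a fixed indices file.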

Example:

```shell
python train.py --model resnet --data_dir data/PokemonData --epochs 10 --train_batch_size 32 --test_batch_size 32
```

## Inference

To perform inference on a single image, use the `inference.py` script. Here are the parameters you can specify:

```python
def main():
    parser = argparse.ArgumentParser(description="Image Inference")
    parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet")
    parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth")
    parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/Chansey/57ccf27cba024fac9531baa9f619ec62.jpg")
    parser.add_argument("--num_classes", type=int, help="Number of classes", default=150)
    parser.add_argument("--lime_interpretability", action="store_true", help="Whether to run interpretability or not")
    parser.add_argument("--classify", action="store_true", help="Whether to classify the image when saving the lime filter")
    args = parser.parse_args()

    if args.lime_interpretability:
        assert args.model_name == "resnet", "Interpretability is only supported for ResNet model for now"
```

Example:

```shell
python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10
```

## Generating Data Samples

To generate data samples, use the `get_samples.py` script. Here are the parameters you can specify:

```python
def main():
    parser = argparse.ArgumentParser(description="Generate Data Samples")
    parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet")
    parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth")
    parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/")
    parser.add_argument("--num_classes", type=int, help="Number of classes", default=150)
    parser.add_argument("--label", type=str, help="Label to filter the images", default='Dragonair')
    parser.add_argument("--num_correct", type=int, help="Number of correctly classified images", default=5)
    parser.add_argument("--num_incorrect", type=int, help="Number of incorrectly classified images", default=5)
    args = parser.parse_args()
```

Example:

```shell
python get_samples.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image_directory --num_classes 10 --label Pikachu --num_correct 5 --num_incorrect 5
```

## Interpretability

To interpret the model's predictions using LIME, use the `inference.py` script with the `--lime_interpretability` flag.

Example:

```shell
python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10 --lime_interpretability
```

## Contributing

Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.