---
license: mit
base_model:
- microsoft/resnet-18
pipeline_tag: image-classification
library_name: pytorch
---
# PokemonClassification
This repository explores training several models on a Pokémon image classification task, with a particular focus on reproducibility and on local interpretability of a ResNet model's decisions using LIME.
## Table of Contents
- [PokemonClassification](#pokemonclassification)
- [Table of Contents](#table-of-contents)
- [Installation](#installation)
- [Dataset](#dataset)
- [Training](#training)
- [Inference](#inference)
- [Generating Data Samples](#generating-data-samples)
- [Interpretability](#interpretability)
- [Contributing](#contributing)
## Installation
1. Clone the repository:
```sh
git clone https://github.com/yourusername/PokemonClassification.git
cd PokemonClassification
```
2. Create a conda environment and activate it:
```sh
conda env create -f environment.yaml
conda activate pokemonclassification
```
## Dataset
To get the data, use the appropriate script based on your operating system:
- On Linux-based systems:
```shell
./get_data.sh
```
- On Windows:
```shell
./get_data.ps1
```
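The default paths used throughout the scripts (e.g. `PokemonData/Chansey/…`) suggest the data is laid out as one sub-folder per Pokémon class. Assuming that layout, a minimal sketch of loading the dataset with `torchvision`'s `ImageFolder` (the exact transforms used by `train.py` may differ):
```python
# Minimal loading sketch, assuming one sub-folder per class under PokemonData/.
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # input size expected by the ImageNet backbones
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = datasets.ImageFolder("./pokemonclassification/PokemonData", transform=transform)
print(f"{len(dataset)} images across {len(dataset.classes)} classes")
```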
## Training
To train a model, use the `train.py` script. Here are the parameters you can specify:
```python
def parser_args():
    parser = argparse.ArgumentParser(description="Pokemon Classification")
    parser.add_argument("--data_dir", type=str, default="./pokemonclassification/PokemonData", help="Path to the data directory")
    parser.add_argument("--indices_file", type=str, default="indices_60_32.pkl", help="Path to the indices file")
    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--train_batch_size", type=int, default=128, help="train Batch size")
    parser.add_argument("--test_batch_size", type=int, default=512, help="test Batch size")
    parser.add_argument("--model", type=str, choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"], default="resnet", help="Model to be used")
    parser.add_argument("--feature_extract", type=bool, default=True, help="whether to freeze the backbone or not")
    parser.add_argument("--use_pretrained", type=bool, default=True, help="whether to use pretrained model or not")
    parser.add_argument("--experiment_id", type=int, default=0, help="Experiment ID to log the results")
    return parser.parse_args()
```
Example:
```shell
python train.py --model resnet --data_dir data/PokemonData --epochs 10 --train_batch_size 32 --test_batch_size 32
```
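For context, `--feature_extract` controls whether the pretrained backbone is frozen so that only the new classification head is trained. The sketch below shows one common way to set this up for the ResNet option; the helper name and details are illustrative, not necessarily what `train.py` actually does.
```python
# Illustrative sketch of feature extraction with a pretrained ResNet-18;
# build_resnet is a hypothetical helper, not the repository's actual code.
import torch.nn as nn
from torchvision import models

def build_resnet(num_classes: int, feature_extract: bool = True, use_pretrained: bool = True):
    weights = models.ResNet18_Weights.DEFAULT if use_pretrained else None
    model = models.resnet18(weights=weights)
    if feature_extract:
        for param in model.parameters():
            param.requires_grad = False  # freeze the backbone
    # Replace the final layer; its parameters are trainable by default.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

model = build_resnet(num_classes=150)
```
With the backbone frozen, only the parameters of `model.fc` receive gradient updates during training.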
## Inference
To perform inference on a single image, use the `inference.py` script. Here are the parameters you can specify:
```python
def main():
    parser = argparse.ArgumentParser(description="Image Inference")
    parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet")
    parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth")
    parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/Chansey/57ccf27cba024fac9531baa9f619ec62.jpg")
    parser.add_argument("--num_classes", type=int, help="Number of classes", default=150)
    parser.add_argument("--lime_interpretability", action="store_true", help="Whether to run interpretability or not")
    parser.add_argument("--classify", action="store_true", help="Whether to classify the image when saving the lime filter")
    args = parser.parse_args()
    if args.lime_interpretability:
        assert args.model_name == "resnet", "Interpretability is only supported for ResNet model for now"
```
Example:
```shell
python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10
```
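For reference, single-image inference with a saved ResNet checkpoint roughly amounts to loading the weights, preprocessing the image, and taking the arg-max over class probabilities. A hedged sketch (the actual `inference.py` pipeline may differ, e.g. in its transforms or checkpoint format):
```python
# Hedged inference sketch; assumes the checkpoint is a plain state_dict
# for a ResNet-18 whose head was resized to num_classes.
import torch
from PIL import Image
from torchvision import models, transforms

def predict(weights_path: str, image_path: str, num_classes: int = 150) -> int:
    model = models.resnet18()
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    model.eval()

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image = preprocess(Image.open(image_path).convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        probs = torch.softmax(model(image), dim=1)
    return int(probs.argmax(dim=1))  # predicted class index
```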
## Generating Data Samples
To generate data samples, use the `get_samples.py` script. Here are the parameters you can specify:
```python
def main():
    parser = argparse.ArgumentParser(description="Generate Data Samples")
    parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet")
    parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth")
    parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/")
    parser.add_argument("--num_classes", type=int, help="Number of classes", default=150)
    parser.add_argument("--label", type=str, help="Label to filter the images", default='Dragonair')
    parser.add_argument("--num_correct", type=int, help="Number of correctly classified images", default=5)
    parser.add_argument("--num_incorrect", type=int, help="Number of incorrectly classified images", default=5)
    args = parser.parse_args()
```
Example:
```shell
python get_samples.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image_directory --num_classes 10 --label Pikachu --num_correct 5 --num_incorrect 5
```
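Conceptually, the script runs the images of the requested label through the model and keeps the first `--num_correct` correctly classified and `--num_incorrect` misclassified ones. A rough sketch of that loop (function names here are illustrative, not the script's actual helpers):
```python
# Illustrative sketch only; predict_fn and class_to_idx are assumed inputs,
# not the repository's actual helpers.
import os

def collect_samples(predict_fn, data_dir, label, class_to_idx, num_correct=5, num_incorrect=5):
    correct, incorrect = [], []
    target = class_to_idx[label]
    for name in sorted(os.listdir(os.path.join(data_dir, label))):
        path = os.path.join(data_dir, label, name)
        is_correct = predict_fn(path) == target  # predicted class index vs. ground truth
        if is_correct and len(correct) < num_correct:
            correct.append(path)
        elif not is_correct and len(incorrect) < num_incorrect:
            incorrect.append(path)
        if len(correct) == num_correct and len(incorrect) == num_incorrect:
            break
    return correct, incorrect
```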
## Interpretability
To interpret the model's predictions using LIME, use the `inference.py` script with the `--lime_interpretability` flag.
Example:
```shell
python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10 --lime_interpretability
```
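LIME perturbs superpixels of the input image, fits a simple local surrogate model around the prediction, and highlights the superpixels that contributed most. A minimal sketch using the `lime` package is shown below; the exact plumbing in `inference.py` (paths, transforms, how the overlay is saved) may differ.
```python
# Minimal LIME sketch with the lime package; file paths and preprocessing
# are assumptions, not necessarily what inference.py does.
import numpy as np
import torch
from PIL import Image
from lime import lime_image
from skimage.segmentation import mark_boundaries
from torchvision import models, transforms

# Load the trained model (assumes a plain state_dict for a 150-class ResNet-18).
model = models.resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, 150)
model.load_state_dict(torch.load("./trained_models/pokemon_resnet.pth", map_location="cpu"))
model.eval()

to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def batch_predict(images):
    # LIME passes a batch of HxWxC arrays; return class probabilities.
    batch = torch.stack([to_tensor(img.astype(np.uint8)) for img in images])
    with torch.no_grad():
        return torch.softmax(model(batch), dim=1).numpy()

image_np = np.array(Image.open("path_to_your_image.jpg").convert("RGB").resize((224, 224)))

explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(image_np, batch_predict,
                                         top_labels=1, hide_color=0, num_samples=1000)
temp, mask = explanation.get_image_and_mask(explanation.top_labels[0],
                                            positive_only=True, num_features=5, hide_rest=False)
overlay = mark_boundaries(temp / 255.0, mask)  # image with the explanation boundaries drawn
```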
## Contributing
Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.