--- license: mit base_model: - microsoft/resnet-18 pipeline_tag: image-classification library_name: pytorch --- # PokemonClassification This repository explores the training of different models for a vision classification task, with a special focus made on reproducibility, and an attempt to a local interpretability of the decision made by a resnet model using LIME ## Table of Contents - [PokemonClassification](#pokemonclassification) - [Table of Contents](#table-of-contents) - [Installation](#installation) - [Dataset](#dataset) - [Training](#training) - [Inference](#inference) - [Generating Data Samples](#generating-data-samples) - [Interpretability](#interpretability) - [Contributing](#contributing) ## Installation 1. Clone the repository: ```sh git clone https://github.com/yourusername/PokemonClassification.git cd PokemonClassification ``` 2. Create a conda environment and activate it: ```sh conda env create -f environment.yaml conda activate pokemonclassification ``` ## Dataset To get the data, use the appropriate script based on your operating system: - On Linux-based systems: ```shell ./get_data.sh ``` - On Windows: ```shell ./get_data.ps1 ``` ## Training To train a model, use the `train.py` script. Here are the parameters you can specify: ```python def parser_args(): parser = argparse.ArgumentParser(description="Pokemon Classification") parser.add_argument("--data_dir", type=str, default="./pokemonclassification/PokemonData", help="Path to the data directory") parser.add_argument("--indices_file", type=str, default="indices_60_32.pkl", help="Path to the indices file") parser.add_argument("--epochs", type=int, default=20, help="Number of epochs") parser.add_argument("--lr", type=float, default=0.001, help="Learning rate") parser.add_argument("--train_batch_size", type=int, default=128, help="train Batch size") parser.add_argument("--test_batch_size", type=int, default=512, help="test Batch size") parser.add_argument("--model", type=str, choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"], default="resnet", help="Model to be used") parser.add_argument("--feature_extract", type=bool, default=True, help="whether to freeze the backbone or not") parser.add_argument("--use_pretrained", type=bool, default=True, help="whether to use pretrained model or not") parser.add_argument("--experiment_id", type=int, default=0, help="Experiment ID to log the results") return parser.parse_args() ``` Example: ```shell python train.py--model resnet --data_dir data/PokemonData --epochs 10 --train_batch_size 32 --test_batch_size 32 ``` ## Inference To perform inference on a single image, use the `inference.py` script. Here are the parameters you can specify: ```python def main(): parser = argparse.ArgumentParser(description="Image Inference") parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet") parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth") parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/Chansey/57ccf27cba024fac9531baa9f619ec62.jpg") parser.add_argument("--num_classes", type=int, help="Number of classes", default=150) parser.add_argument("--lime_interpretability", action="store_true", help="Whether to run interpretability or not") parser.add_argument("--classify", action="store_true", help="Whether to classify the image when saving the lime filter") args = parser.parse_args() if args.lime_interpretability: assert args.model_name == "resnet", "Interpretability is only supported for ResNet model for now" ``` Example: ```shell python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10 ``` ## Generating Data Samples To generate data samples, use the `get_samples.py` script. Here are the parameters you can specify: ```python def main(): parser = argparse.ArgumentParser(description="Generate Data Samples") parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet") parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth") parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/") parser.add_argument("--num_classes", type=int, help="Number of classes", default=150) parser.add_argument("--label", type=str, help="Label to filter the images", default='Dragonair') parser.add_argument("--num_correct", type=int, help="Number of correctly classified images", default=5) parser.add_argument("--num_incorrect", type=int, help="Number of incorrectly classified images", default=5) args = parser.parse_args() ``` Example: ```shell python get_samples.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image_directory --num_classes 10 --label Pikachu --num_correct 5 --num_incorrect 5 ``` ## Interpretability To interpret the model's predictions using LIME, use the `inference.py` script with the `--lime_interpretability` flag. Example: ```shell python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10 --lime_interpretability ``` ## Contributing Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.