metadata
license: mit
base_model:
  - microsoft/resnet-18
pipeline_tag: image-classification
library_name: pytorch

PokemonClassification

This repository explores training different models on an image classification task, with a particular focus on reproducibility. It also attempts a local interpretation of the decisions made by a ResNet model, using LIME.


Installation

  1. Clone the repository:

    git clone https://github.com/yourusername/PokemonClassification.git
    cd PokemonClassification
    
  2. Create a conda environment and activate it:

    conda env create -f environment.yaml
    conda activate pokemonclassification
    

Dataset

To get the data, use the appropriate script based on your operating system:

  • On Linux-based systems:

    ./get_data.sh
    
  • On Windows:

    ./get_data.ps1
    

Training

To train a model, use the train.py script. Here are the parameters you can specify:

def parser_args():
    parser = argparse.ArgumentParser(description="Pokemon Classification")
    parser.add_argument("--data_dir", type=str, default="./pokemonclassification/PokemonData", help="Path to the data directory")
    parser.add_argument("--indices_file", type=str, default="indices_60_32.pkl", help="Path to the indices file")
    parser.add_argument("--epochs", type=int, default=20, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--train_batch_size", type=int, default=128, help="train Batch size")
    parser.add_argument("--test_batch_size", type=int, default=512, help="test Batch size")
    parser.add_argument("--model", type=str, choices=["resnet", "alexnet", "vgg", "squeezenet", "densenet"], default="resnet", help="Model to be used")
    parser.add_argument("--feature_extract", type=bool, default=True, help="whether to freeze the backbone or not")
    parser.add_argument("--use_pretrained", type=bool, default=True, help="whether to use pretrained model or not")
    parser.add_argument("--experiment_id", type=int, default=0, help="Experiment ID to log the results")
    return parser.parse_args()
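
One caveat about --feature_extract and --use_pretrained as declared above: argparse applies bool() to the raw command-line string, and any non-empty string (including "False") is truthy, so passing --feature_extract False would still yield True. The stand-alone sketch below reproduces this behaviour and shows one possible workaround. The str2bool helper is hypothetical and is not part of train.py.

```python
import argparse


def str2bool(value: str) -> bool:
    """Hypothetical helper: map common true/false spellings to a bool."""
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


# Same declaration style as in the listing above: type=bool.
naive = argparse.ArgumentParser()
naive.add_argument("--feature_extract", type=bool, default=True)

# Alternative declaration that parses the string explicitly.
fixed = argparse.ArgumentParser()
fixed.add_argument("--feature_extract", type=str2bool, default=True)

naive_args = naive.parse_args(["--feature_extract", "False"])
fixed_args = fixed.parse_args(["--feature_extract", "False"])

print(naive_args.feature_extract)  # True, because bool("False") is truthy
print(fixed_args.feature_extract)  # False
```

With the declarations as they stand, the safest reading is that both options are effectively always True when set from the command line.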

Example:

python train.py --model resnet --data_dir data/PokemonData --epochs 10 --train_batch_size 32 --test_batch_size 32
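
Since reproducibility is a stated goal and train.py accepts an --indices_file, the following is a minimal sketch of how a fixed train/test split could be saved once and reused across runs. The format of indices_60_32.pkl is not documented here, so the dictionary layout ("train" and "test" lists), the function names, and the default values are all assumptions made for illustration.

```python
import pickle
import random
from pathlib import Path


def make_split(num_items: int, train_fraction: float, seed: int) -> dict:
    """Shuffle item indices with a fixed seed and cut them into two lists."""
    rng = random.Random(seed)  # local RNG, so global random state is untouched
    indices = list(range(num_items))
    rng.shuffle(indices)
    cut = int(num_items * train_fraction)
    return {"train": indices[:cut], "test": indices[cut:]}


def load_or_create_split(path: Path, num_items: int,
                         train_fraction: float = 0.8, seed: int = 0) -> dict:
    """Reuse the split stored at `path` if it exists, else create and save it."""
    if path.exists():
        with path.open("rb") as handle:
            return pickle.load(handle)
    split = make_split(num_items, train_fraction, seed)
    with path.open("wb") as handle:
        pickle.dump(split, handle)
    return split
```

Once the file exists, every later run sees the same split regardless of the seed passed in, which keeps results comparable between experiments.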

Inference

To perform inference on a single image, use the inference.py script. Here are the parameters you can specify:

def main():
    parser = argparse.ArgumentParser(description="Image Inference")
    parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet")
    parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth")
    parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/Chansey/57ccf27cba024fac9531baa9f619ec62.jpg")
    parser.add_argument("--num_classes", type=int, help="Number of classes", default=150)
    parser.add_argument("--lime_interpretability", action="store_true", help="Whether to run interpretability or not")
    parser.add_argument("--classify", action="store_true", help="Whether to classify the image when saving the lime filter")
    args = parser.parse_args()

    if args.lime_interpretability:
        assert args.model_name == "resnet", "Interpretability is only supported for ResNet model for now"

Example:

python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10
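
A classifier outputs one score per class index, so the index has to be mapped back to a class name. The default image path (PokemonData/Chansey/...) suggests one folder per class. The sketch below assumes that class indices follow the alphabetically sorted folder names, which is the convention used by torchvision's ImageFolder; whether inference.py does this is not confirmed here, and the function names are hypothetical.

```python
import math
from pathlib import Path
from typing import List, Tuple


def class_names_from_dir(data_dir: str) -> List[str]:
    """Return class folder names sorted alphabetically (assumed index order)."""
    return sorted(p.name for p in Path(data_dir).iterdir() if p.is_dir())


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of raw scores."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits: List[float], names: List[str], k: int = 3) -> List[Tuple[str, float]]:
    """Pair each probability with its class name and keep the k most likely."""
    probs = softmax(logits)
    ranked = sorted(zip(names, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]
```

If the assumption holds, the number of names returned by class_names_from_dir should equal the value passed as --num_classes, which in turn has to match the output size of the saved weights.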

Generating Data Samples

To generate data samples, use the get_samples.py script. Here are the parameters you can specify:

def main():
    parser = argparse.ArgumentParser(description="Generate Data Samples")
    parser.add_argument("--model_name", type=str, help="Model name (resnet, alexnet, vgg, squeezenet, densenet)", default="resnet")
    parser.add_argument("--model_weights", type=str, help="Path to the model weights", default="./trained_models/pokemon_resnet.pth")
    parser.add_argument("--image_path", type=str, help="Path to the image", default="./pokemonclassification/PokemonData/")
    parser.add_argument("--num_classes", type=int, help="Number of classes", default=150)
    parser.add_argument("--label", type=str, help="Label to filter the images", default='Dragonair')
    parser.add_argument("--num_correct", type=int, help="Number of correctly classified images", default=5)
    parser.add_argument("--num_incorrect", type=int, help="Number of incorrectly classified images", default=5)
    args = parser.parse_args()

Example:

python get_samples.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image_directory --num_classes 10 --label Pikachu --num_correct 5 --num_incorrect 5
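
The --num_correct and --num_incorrect options imply a selection step over the model's predictions. As a rough sketch of that idea, assuming --label refers to the ground-truth class and that predictions are available as (image_path, true_label, predicted_label) tuples, the selection could look like this. The pick_samples function is hypothetical and does not come from get_samples.py.

```python
from typing import Iterable, List, Tuple


def pick_samples(
    predictions: Iterable[Tuple[str, str, str]],
    label: str,
    num_correct: int,
    num_incorrect: int,
) -> Tuple[List[str], List[str]]:
    """Collect image paths of the given true label, split by correctness.

    Each item in `predictions` is (image_path, true_label, predicted_label).
    """
    correct: List[str] = []
    incorrect: List[str] = []
    for path, true_label, predicted_label in predictions:
        if true_label != label:
            continue  # only keep images whose ground truth is `label`
        if predicted_label == true_label and len(correct) < num_correct:
            correct.append(path)
        elif predicted_label != true_label and len(incorrect) < num_incorrect:
            incorrect.append(path)
        if len(correct) >= num_correct and len(incorrect) >= num_incorrect:
            break  # both quotas are filled, so stop scanning early
    return correct, incorrect
```

Either list may come back shorter than requested when the model makes too few correct or incorrect predictions for the chosen label.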

Interpretability

To interpret the model's predictions using LIME, use the inference.py script with the --lime_interpretability flag.

Example:

python inference.py --model_name resnet --model_weights path_to_your_model_weights.pth --image_path path_to_your_image.jpg --num_classes 10 --lime_interpretability
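
For readers unfamiliar with LIME, here is a minimal sketch of how an image explanation is typically produced with the lime package. It is not taken from inference.py: predict_logits stands for a hypothetical function that runs the trained ResNet on a batch of H x W x 3 images and returns raw scores, and the parameter values (num_samples, num_features) are illustrative defaults.

```python
import numpy as np


def make_classifier_fn(predict_logits):
    """Wrap a logits function into the (N, H, W, 3) -> (N, classes) form LIME expects."""

    def classifier_fn(images: np.ndarray) -> np.ndarray:
        logits = np.asarray(predict_logits(images), dtype=np.float64)
        shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
        exps = np.exp(shifted)
        return exps / exps.sum(axis=1, keepdims=True)  # softmax probabilities

    return classifier_fn


def explain_image(image: np.ndarray, classifier_fn, num_samples: int = 1000):
    """Run LIME on one H x W x 3 image and return (overlay_image, mask)."""
    from lime import lime_image  # imported lazily; requires the `lime` package

    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        image, classifier_fn, top_labels=1, hide_color=0, num_samples=num_samples
    )
    # Keep the superpixels that support the top predicted class.
    return explanation.get_image_and_mask(
        explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=False
    )
```

LIME perturbs the image by hiding superpixels, queries classifier_fn on the perturbed copies, and fits a local linear model to see which regions drive the prediction; the returned mask marks those regions.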

Contributing

Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.