MedMNIST Active Learning Model

Overview

This model is designed for medical image classification, specifically the PathMNIST dataset from the MedMNIST collection. It uses a ResNet-50 architecture adapted to 28x28-pixel inputs and applies active learning strategies to reach strong performance with a limited amount of labeled data.

Model Architecture

  • Base Model: ResNet-50
  • Modifications:
    • Adjusted initial convolution layer to accommodate 28x28 input images.
    • Removed max pooling layer to preserve spatial dimensions.
    • Customized fully connected layer to output predictions for 9 classes.

Training Procedure

Training Hyperparameters

Hyperparameter            Value
Batch Size                53
Initial Labeled Size      3559
Learning Rate             0.01332344940133225
MC Dropout Passes         6
Samples to Label          4430
Weight Decay              0.00021921795989143406

Optimizer Settings

The model was trained with Stochastic Gradient Descent (SGD) using the settings below, together with a ReduceLROnPlateau learning-rate scheduler:

  • learning_rate = 0.01332344940133225
  • momentum = 0.9
  • weight_decay = 0.00021921795989143406

The model was trained with float32 precision.
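A minimal PyTorch sketch of this optimizer and scheduler configuration follows. The SGD values come from the settings listed above. The scheduler's `mode="min"` (monitoring validation loss) and its default `factor` and `patience` are assumptions, since the card does not state them. A small `nn.Linear` stands in for the actual network.

```python
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Stand-in module; in practice this would be the ResNet-50 variant.
model = nn.Linear(10, 9)

optimizer = SGD(
    model.parameters(),
    lr=0.01332344940133225,
    momentum=0.9,
    weight_decay=0.00021921795989143406,
)
# Assumption: the scheduler tracks validation loss (mode="min") with PyTorch's
# default factor=0.1 and patience=10; the card does not give these values.
scheduler = ReduceLROnPlateau(optimizer, mode="min")

# After each validation pass, feed the monitored metric to the scheduler.
for val_loss in [0.9, 0.8, 0.85]:
    scheduler.step(val_loss)

# Only three steps have passed, fewer than patience, so the LR is unchanged.
print(optimizer.param_groups[0]["lr"])
```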

Dataset

PathMNIST

Data Augmentation

  • Random resized cropping
  • Horizontal flipping
  • Random rotations
  • Color jittering
  • Gaussian blur
  • RandAugment

Active Learning Strategy

The active learning process was based on a mixed sampling strategy:

  • Uncertainty Sampling: Monte Carlo (MC) dropout, with 6 stochastic forward passes, was used to estimate predictive uncertainty.
  • Diversity Sampling: K-means clustering was used to select a diverse set of samples.
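One way the two criteria might be combined is sketched below. The sketch assumes the classifier contains dropout layers; a stock ResNet-50 has none, so the project's model presumably adds them. It averages six stochastic forward passes and ranks the unlabeled pool by predictive entropy. It then runs K-means on the most uncertain candidates and takes one sample per cluster. Several details are hypothetical choices, not the project's documented method: the helper names `mc_dropout_probs` and `select_mixed`, the entropy criterion, the `pool_factor` candidate pool, and clustering on averaged probabilities.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.cluster import KMeans


def mc_dropout_probs(model: nn.Module, x: torch.Tensor, passes: int = 6) -> np.ndarray:
    """Softmax outputs of shape (passes, N, C) with dropout kept active."""
    model.eval()
    # Re-enable only dropout layers, so BatchNorm keeps its running statistics.
    for m in model.modules():
        if isinstance(m, (nn.Dropout, nn.Dropout2d)):
            m.train()
    with torch.no_grad():
        probs = [torch.softmax(model(x), dim=1) for _ in range(passes)]
    return torch.stack(probs).numpy()


def select_mixed(probs: np.ndarray, budget: int, pool_factor: int = 3, seed: int = 0) -> np.ndarray:
    """Hypothetical uncertainty-then-diversity selection; returns pool indices."""
    mean_p = probs.mean(axis=0)  # (N, C) averaged predictive distribution
    entropy = -(mean_p * np.log(mean_p + 1e-12)).sum(axis=1)
    # Step 1 (uncertainty): keep the pool_factor * budget highest-entropy samples.
    n_cand = min(len(entropy), pool_factor * budget)
    cand = np.argsort(-entropy)[:n_cand]
    # Step 2 (diversity): cluster candidates with K-means, then take the most
    # uncertain sample in each cluster. Clustering on mean probabilities is a
    # simplification; penultimate-layer features are a common alternative.
    n_clusters = min(budget, n_cand)
    km = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed).fit(mean_p[cand])
    chosen = []
    for k in range(n_clusters):
        members = cand[km.labels_ == k]
        if len(members) > 0:  # guard against degenerate empty clusters
            chosen.append(members[np.argmax(entropy[members])])
    return np.array(chosen, dtype=int)


# Toy example: a small MLP with dropout stands in for the classifier.
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 3))
pool = torch.randn(50, 4)
probs = mc_dropout_probs(net, pool, passes=6)
picked = select_mixed(probs, budget=5)
print(probs.shape, picked)
```

Filtering by uncertainty first and clustering second keeps the K-means step cheap. It also stops a single confusing region of the data from filling the whole labeling budget.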

Evaluation

The model was evaluated on the validation set of PathMNIST. Key performance metrics include:

  • Accuracy: 94.72%
  • Loss: 0.2397
  • AUC: 99.73%
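The three metrics can be computed from softmax outputs roughly as follows. This sketch assumes mean cross-entropy for the loss and macro-averaged one-vs-rest AUC via scikit-learn's `roc_auc_score`. The card does not specify how AUC was averaged, and the function name `evaluate` is hypothetical.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def evaluate(probs: np.ndarray, labels: np.ndarray) -> dict:
    """Accuracy, mean cross-entropy, and macro one-vs-rest AUC from softmax outputs."""
    accuracy = float((probs.argmax(axis=1) == labels).mean())
    # Cross-entropy: negative log-probability assigned to the true class.
    loss = float(-np.log(probs[np.arange(len(labels)), labels] + 1e-12).mean())
    auc = float(roc_auc_score(labels, probs, multi_class="ovr", average="macro"))
    return {"accuracy": accuracy, "loss": loss, "auc": auc}


# Toy 3-class example: every sample puts 0.8 on its true class.
labels = np.array([0, 1, 2, 0, 1, 2])
probs = np.full((6, 3), 0.1)
probs[np.arange(6), labels] = 0.8
metrics = evaluate(probs, labels)
print(metrics)
```

`roc_auc_score` with `multi_class="ovr"` expects each row of `probs` to sum to 1. It also raises an error if some class is missing from `labels`, unless the class list is passed through its `labels` argument.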

Graphs

The following plots illustrate the validation loss, validation accuracy, and validation AUC over batches (number of iterations over the dataset) during the active learning process.

  • Validation Loss
  • Validation Accuracy
  • Validation AUC

Usage

All code for this model is available in the GitHub repository Determined_AI_Hackathon by Allen Cheung.

To utilize this model:

  1. Install Dependencies: Ensure the following Python packages are installed:

    • torch
    • torchvision
    • medmnist
    • scikit-learn
    • determined

    Install them using pip:

    pip install torch torchvision medmnist scikit-learn determined
    
  2. Load the Model:

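    import torch
    from model import ResNet50_28  # model definition from the project repository
    
    model = ResNet50_28(num_classes=9)
    # map_location='cpu' lets the weights load on a machine without a GPU
    state_dict = torch.load('pytorch_model.bin', map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout and use stored BatchNorm statistics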
    
  3. Inference:

    import torch
    from torchvision import transforms
    from PIL import Image
    
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    
    # PathMNIST images are RGB; convert so the tensor has 3 channels
    image = Image.open('path_to_image.jpg').convert('RGB')
    input_tensor = transform(image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():  # no gradients needed for inference
        output = model(input_tensor)
    prediction = output.argmax(dim=1).item()
    print(f"Predicted class: {prediction}")
    

License

This project is licensed under the MIT License.

Acknowledgements
