ResNet18 Pneumonia Detection Model

This model is a version of ResNet18 fine-tuned from ImageNet-pretrained weights for pneumonia detection. It was trained on the Kaggle Chest X-ray Pneumonia dataset, which contains X-rays of normal lungs and of lungs with pneumonia, and it classifies a chest X-ray as either Pneumonia or Normal.

Model Details

  • Model Architecture: ResNet18
  • Input Size: 224 x 224
  • Number of Classes: 2 (Pneumonia, Normal)
  • Framework: PyTorch
  • Training Dataset: Kaggle Chest X-ray Pneumonia Dataset

Model Performance

  • Accuracy: 83.3%
  • Loss: 0.2459
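
The card does not state which data split these figures come from. As a rough sketch of how the two numbers are usually computed from model outputs: accuracy is the fraction of argmax predictions that match the labels, and the loss is the mean cross-entropy of the softmax over the two logits (what torch.nn.CrossEntropyLoss computes with its default mean reduction). The helper name `accuracy_and_loss` and the toy logits are illustrative assumptions, not the author's evaluation code.

```python
import math
from typing import Sequence, Tuple


def accuracy_and_loss(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> Tuple[float, float]:
    """Return (accuracy, mean cross-entropy) for raw 2-class logits.

    Hypothetical helper written in plain Python for clarity; it mirrors argmax
    accuracy and mean-reduced cross-entropy loss.
    """
    if len(logits) != len(labels) or not labels:
        raise ValueError("logits and labels must be non-empty and the same length")
    correct = 0
    total_loss = 0.0
    for row, y in zip(logits, labels):
        # Numerically stable log-sum-exp: subtract the max logit before exponentiating
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        total_loss += log_sum_exp - row[y]  # equals -log softmax(row)[y]
        # The prediction is the index of the largest logit
        pred = max(range(len(row)), key=lambda i: row[i])
        correct += int(pred == y)
    n = len(labels)
    return correct / n, total_loss / n


# Toy example: three images, two classes (0 = Pneumonia, 1 = Normal in this card's label map)
acc, loss = accuracy_and_loss([[2.0, 0.0], [0.0, 1.0], [1.0, 3.0]], [0, 1, 0])
print(f"accuracy={acc:.3f} loss={loss:.4f}")
```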

Intended Use

This model is designed to assist healthcare professionals in identifying pneumonia from chest X-ray images. It should not be used as a sole diagnostic tool, only as a supplement to medical expertise.

Training Details

The model was trained using the following setup:

  • Architecture: ResNet18 (Pre-trained on ImageNet)
  • Optimizer: SGD (Stochastic Gradient Descent)
    • Learning Rate: 0.001
    • Momentum: 0.9
  • Loss Function: CrossEntropyLoss
  • Batch Size: 32
  • Data Augmentation:
    • Random Rotation (±30 degrees)
    • Random Zoom (20%)
    • Random Horizontal Shift (±10% width)
    • Random Vertical Shift (±10% height)
    • Random Horizontal Flip
  • Training Epochs: 1
  • Evaluation Metric: Cross Entropy Loss

Augmentation Details

The dataset was augmented during training with the following transformations:

  • Randomly rotated training images by up to 30 degrees in either direction
  • Randomly zoomed training images by up to 20%
  • Randomly shifted images horizontally by up to 10% of the width
  • Randomly shifted images vertically by up to 10% of the height
  • Randomly flipped images horizontally

How to Use the Model

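You can load and run this model with PyTorch and torchvision, using huggingface_hub to download the weights (the example also uses Pillow and requests to fetch a sample image).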

import torch
import requests
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import models, transforms

# Download the model weights from Hugging Face Hub
model_path = hf_hub_download(repo_id="izeeek/resnet18_pneumonia_classifier", filename="resnet18_pneumonia_classifier.pth")

# Build the ResNet18 architecture without ImageNet weights (they are replaced by the checkpoint below)
model = models.resnet18(weights=None)

# Replace the 1000-class ImageNet head with a 2-class head to match the checkpoint
model.fc = torch.nn.Linear(model.fc.in_features, 2)

# Load the downloaded weights; map_location="cpu" lets this run on machines without a GPU
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# Set the model to evaluation mode
model.eval()

# Image preprocessing
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

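# Sample image (replace with your own image URL). The link below is a time-limited signed Kaggle URL
# (X-Goog-Expires=345600 seconds from 2024-09-13), so it has expired and will no longer download.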
url = 'https://storage.googleapis.com/kagglesdsdata/datasets/17810/23812/chest_xray/test/NORMAL/IM-0005-0001.jpeg?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=databundle-worker-v2%40kaggle-161607.iam.gserviceaccount.com%2F20240913%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20240913T014624Z&X-Goog-Expires=345600&X-Goog-SignedHeaders=host&X-Goog-Signature=1f6b37d181f12d083ffc951657e85fea087bb4e81ab955ec955dafcdae49c0d53ce20bc0be93605e2672b9bdd59e752eba9d5a3a0da2e3b3a03c888580b88d63d87611b4e4cec8b8802d53abd53fda165dd04765b8d9f30ddd4e908cd7a2a389ce8244fca7bfa36b3c9cff79d7c5e3f9ee7d59d5b9ef97a2e5c083997892ee3023302313fafff48ded58232db57d6affcfaee704eebba55f2b0abac40b14a38137275ad19cdb1b787930d134f7c30710e29c409bd765ca02e46851470a871cc697f614d464086373f43f5462f241eaf023cfd31e217d7b11e24e1ff34857deb200f5dc1a8c28c8115048ee840be8481f1bd79a2d8e2de1b30cb71420c007d32c'
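response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()  # fail early if the URL is invalid or has expired
img = Image.open(response.raw)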

# Preprocess the image
input_img = transform(img).unsqueeze(0)

# Inference
with torch.no_grad():
    output = model(input_img)
    _, predicted = torch.max(output, 1)

# Labels for classification
labels = {0: 'Pneumonia', 1: 'Normal'}
print(f'Predicted label: {labels[predicted.item()]}')