Model Card: ResNet18 for Fruit Classification

Introduction

This model card describes a ResNet18-based model that classifies fruit images into 33 categories, trained with the accompanying training code.

Model Details

  • Model Architecture: ResNet18
  • Number of Classes: 33 fruit categories
  • Loss Function: Cross Entropy Loss
  • Optimizer: Stochastic Gradient Descent (SGD) with a learning rate of 0.001
  • Training Duration: 5 epochs

Training and Evaluation

The model was trained on the training split and evaluated on the held-out test split of the fruit dataset. After each epoch, training loss, training accuracy, test loss, and test accuracy were recorded.
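A per-epoch loop that records those four metrics might look like the following minimal PyTorch sketch. The function names `train_one_epoch`, `evaluate`, and `fit` are hypothetical and not taken from the original training code; the loss averaging assumes the criterion uses its default `reduction='mean'`.

```python
import torch


def train_one_epoch(model, loader, criterion, optimizer, device="cpu"):
    """Run one pass over the training data; return (mean loss, accuracy in %)."""
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch-mean loss by batch size so the epoch mean is per sample
        total_loss += loss.item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, 100.0 * correct / total


@torch.no_grad()
def evaluate(model, loader, criterion, device="cpu"):
    """Score the model on held-out data without computing gradients."""
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        total_loss += criterion(logits, labels).item() * labels.size(0)
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    return total_loss / total, 100.0 * correct / total


def fit(model, train_loader, test_loader, criterion, optimizer, epochs=5):
    """Train for `epochs` epochs and print metrics in this card's format."""
    history = []
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer)
        test_loss, test_acc = evaluate(model, test_loader, criterion)
        history.append((train_loss, train_acc, test_loss, test_acc))
        print(f"Epoch {epoch}/{epochs} | Train Loss: {train_loss:.4f} | "
              f"Train Accuracy: {train_acc:.2f}% | Test Loss: {test_loss:.4f} | "
              f"Test Accuracy: {test_acc:.2f}%")
    return history
```

In a loop like this, training metrics are averaged over batches while the weights are still changing, whereas test metrics are computed after the epoch ends. If the original code worked the same way, that would explain why epoch 1 reports 70.68% training accuracy but 99.11% test accuracy.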

Performance

The performance metrics for each epoch are as follows:

  Epoch | Train Loss | Train Accuracy | Test Loss | Test Accuracy
  ------+------------+----------------+-----------+--------------
  1/5   | 2.2403     | 70.68%         | 0.2475    | 99.11%
  2/5   | 0.1282     | 99.65%         | 0.0771    | 99.82%
  3/5   | 0.0568     | 99.89%         | 0.0514    | 99.76%
  4/5   | 0.0347     | 99.96%         | 0.0332    | 99.91%
  5/5   | 0.0247     | 99.97%         | 0.0240    | 99.94%
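For reference, the standard definitions behind these numbers are given below. They assume the reported losses are per-sample means, as PyTorch's `nn.CrossEntropyLoss` computes with its default `reduction='mean'`; the card does not state this explicitly. Here $z_{i,c}$ is the logit for class $c$ on image $i$, $y_i$ is the true label, and $N$ is the number of images in the split.

```latex
\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \log \frac{\exp(z_{i,y_i})}{\sum_{c=1}^{33} \exp(z_{i,c})},
\qquad
\text{Accuracy} = \frac{100\%}{N}\sum_{i=1}^{N} \mathbf{1}\!\left[\arg\max_{c}\, z_{i,c} = y_i\right]
```

A model that assigns equal probability to all 33 classes has a loss of $\ln 33 \approx 3.50$, a useful baseline when reading the epoch 1 training loss of 2.2403.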

Usage

To classify a fruit image, build a ResNet18 with a 33-class output layer, load the trained weights, apply the same preprocessing used during training, and run a forward pass. The index of the largest output logit is the predicted category (0 to 32).

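import torch
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18

# Load the trained weights into a ResNet18 with a 33-way output layer.
# Assumes the checkpoint is a state_dict whose final layer was resized to 33
# classes; if it kept the default 1000-way head, use resnet18() instead.
model = resnet18(num_classes=33)
model_weights_path = 'resnet18_fruit_classifier.pth'
model.load_state_dict(torch.load(model_weights_path, map_location='cpu'))
model.eval()  # switch batch-norm layers to inference mode

# Preprocessing; this should match the transform used during training
transform = transforms.Compose([
    transforms.Resize(255),
    transforms.ToTensor()
])

# Load and preprocess a sample image ('fruit.jpg' is a placeholder path)
image = Image.open('fruit.jpg').convert('RGB')
image = transform(image).unsqueeze(0)  # add a batch dimension: [1, 3, H, W]

# Forward pass to get predictions
with torch.no_grad():
    output = model(image)

# Process the output to get the predicted class index (0 to 32)
predicted_class = torch.argmax(output, dim=1)
print("Predicted Class:", predicted_class.item())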
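The example prints only a class index. Turning it into a fruit name requires the class list in the order used during training; if training used torchvision's `ImageFolder` (an assumption), that order is its `classes` attribute, i.e. the folder names sorted alphabetically. The sketch below reports the top few guesses with probabilities; `top_k_predictions` and `class_names` are hypothetical, since the 33 category names are not listed in this card.

```python
import torch


def top_k_predictions(logits, class_names, k=3):
    """Return the k most likely (class name, probability) pairs for one image.

    `logits` has shape [1, num_classes], as produced by model(image).
    `class_names` is a hypothetical list of names in training-label order.
    """
    probs = torch.softmax(logits, dim=1)[0]  # convert logits to probabilities
    values, indices = torch.topk(probs, k)   # highest-probability classes first
    return [(class_names[i], p) for i, p in zip(indices.tolist(), values.tolist())]
```

With `output` from the usage example, `top_k_predictions(output, class_names)` returns the three best guesses, which is more informative than a single index when two fruits look alike.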

Limitations and Considerations

  • Performance may vary with the quality and diversity of the training data.
  • Five epochs of training may not be enough to reach optimal performance; further fine-tuning and experimentation may be needed.
  • Additional data augmentation and regularization may improve the model's robustness and accuracy.

Ethical Considerations

Ensure that the dataset used for training is collected and used ethically, respecting privacy, consent, and applicable laws and regulations.

Disclaimer

This model card is for illustrative purposes and does not guarantee any specific performance or outcomes when using the provided code. Users are encouraged to conduct thorough evaluation and testing for their specific use cases.