Model Card for Diagram Classification Model

Model Details

Model Description

This is a fine-tuned ResNet-18 model trained for binary image classification, distinguishing between diagrams and non-diagrams. The model is designed for use in applications that need automatic filtering or processing of diagram-based content.

  • Developed by: Aya Mohamed
  • Model type: ResNet-18 (Fine-tuned for image classification)
  • Language(s) (NLP): Not applicable (Computer Vision model)
  • License: Apache 2.0
  • Finetuned from model: microsoft/resnet-18

Uses

Direct Use

This model is intended for classifying images as diagrams or non-diagrams. It can be used in:

  • Document processing (extracting diagrams from PDFs or scanned documents)
  • Chart-based visual question generation (VQG)
  • Content moderation (filtering diagram images from general image datasets)
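The filtering use case can be sketched as a small helper that walks a folder and keeps only the images a classifier labels as diagrams. This is a minimal sketch, not part of the released model: `list_images`, `keep_diagrams`, and the extension set are hypothetical names, and `predict_fn` is assumed to be any callable that takes an image path and returns "Diagram" or "Not Diagram", the label strings used in this card's usage example.

```python
from pathlib import Path
from typing import Callable, Iterable, List

# Extensions treated as images; this set is an assumption, adjust as needed.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}


def list_images(folder: str) -> List[Path]:
    """Return image files directly inside `folder`, sorted by name."""
    return sorted(
        p for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


def keep_diagrams(paths: Iterable[Path],
                  predict_fn: Callable[[str], str]) -> List[Path]:
    """Keep only the paths that `predict_fn` labels as "Diagram"."""
    return [p for p in paths if predict_fn(str(p)) == "Diagram"]
```

With a classifier function such as the `predict` example in the How to Use section, a call like `keep_diagrams(list_images("scans"), predict)` would return the diagram files in a hypothetical `scans` folder.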

Out-of-Scope Use

  • Not suitable for multi-class classification beyond diagrams vs. non-diagrams.
  • Not designed for hand-drawn sketches or complex figures with mixed elements.

Bias, Risks, and Limitations

  • The model's accuracy depends on the training dataset, which may not cover all possible diagram styles.
  • May misclassify charts, blueprints, or artistic drawings if they resemble diagrams.

Recommendations

Users should evaluate the model on their specific dataset before deployment to ensure it performs well in their context.
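One way to run such a check is to collect predictions on a labeled sample of your own images and compute standard binary metrics. The sketch below is plain Python and makes two assumptions: labels are the strings "Diagram" and "Not Diagram" (as returned by this card's usage example), and "Diagram" is treated as the positive class, which the card does not state explicitly. The function name `binary_metrics` is hypothetical.

```python
from typing import Dict, Sequence


def binary_metrics(y_true: Sequence[str], y_pred: Sequence[str],
                   positive: str = "Diagram") -> Dict[str, float]:
    """Accuracy, precision, and recall with `positive` as the positive class."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("at least one labeled example is required")

    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    correct = sum(t == p for t, p in pairs)

    return {
        "accuracy": correct / len(pairs),
        # Fall back to 0.0 when a denominator is empty (no positives seen).
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
    }
```

Low recall on your data would suggest diagram styles missing from the training sets, while low precision would suggest non-diagram images that resemble diagrams.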

🚀 How to Use

1️⃣ Load the Model from Hugging Face

Download the checkpoint from the Hugging Face Hub with huggingface_hub, then load it with torch.

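import torch
from huggingface_hub import hf_hub_download

# Download the checkpoint from the Hugging Face Hub
model_path = hf_hub_download(repo_id="Ayamohamed/DiaClassification", filename="model.pth")

# The loaded object is used directly as a model, so the checkpoint is assumed
# to be a fully pickled model rather than a state_dict. PyTorch 2.6 and later
# default to weights_only=True, which rejects such files.
# Unpickling can run arbitrary code: only load checkpoints you trust.
model_hg = torch.load(model_path, map_location="cpu", weights_only=False)
model_hg.eval()  # Set to evaluation mode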

2️⃣ Preprocess and Classify an Image

from PIL import Image
from torchvision import transforms

# Uses torch and model_hg from step 1.
# Define image transformations (ImageNet channel means and standard deviations)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


def predict(image_path):
    # Force three channels so grayscale or RGBA files match the model input
    image = Image.open(image_path).convert("RGB")
    # Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
    image = transform(image).unsqueeze(0)
    with torch.no_grad():
        output = model_hg(image)
        class_idx = torch.argmax(output, dim=1).item()

    return "Diagram" if class_idx == 0 else "Not Diagram"

# Example usage
print(predict("my-diagram-classifier/31188_1536932698.jpg"))
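The `predict` function above returns only the argmax label. When a confidence score is useful, for example to route borderline images to manual review, the two output values can be turned into probabilities with softmax. The following is a minimal plain-Python sketch under stated assumptions: the model emits two raw logits, index 0 means "Diagram" as in the example above, and `label_with_confidence` with its "Uncertain" fallback is a hypothetical helper, not part of the model.

```python
import math
from typing import List, Tuple

# Index order assumed from the usage example: 0 -> "Diagram".
LABELS = ["Diagram", "Not Diagram"]


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def label_with_confidence(logits: List[float],
                          threshold: float = 0.5) -> Tuple[str, float]:
    """Return (label, probability); label is "Uncertain" below `threshold`."""
    probs = softmax(logits)
    idx = max(range(len(probs)), key=probs.__getitem__)
    label = LABELS[idx] if probs[idx] >= threshold else "Uncertain"
    return label, probs[idx]
```

Inside `predict`, the logits could be passed as `output.squeeze(0).tolist()`. With two classes the top probability is never below 0.5, so the default threshold reproduces the argmax behavior; a stricter value such as 0.9 is an arbitrary illustration, not a tuned setting.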

Training Details

Training Data

The model was trained using:

  • ChartQA dataset (for diagram samples)
  • JasmineQiuqiu/diagrams_with_captions_2 (for diagram samples)
  • COCO dataset (subset) (for non-diagram samples)

Training Procedure

  • Pretrained model: microsoft/resnet-18
  • Optimization: Adam optimizer
  • Loss function: Cross-entropy loss
  • Training duration: Approx. X hours on an NVIDIA GPU
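The card names cross-entropy as the loss but not its exact configuration. As an illustration of the quantity being minimized, here is a minimal plain-Python sketch, assuming unweighted natural-log cross-entropy over two raw logits, averaged over the batch (the default behavior of PyTorch's `nn.CrossEntropyLoss`). The function names are hypothetical and this is not the author's training code.

```python
import math
from typing import List, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Loss for one sample: -log(softmax(logits)[target])."""
    m = max(logits)
    # log-sum-exp with the max subtracted for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def mean_cross_entropy(batch_logits: List[Sequence[float]],
                       targets: List[int]) -> float:
    """Average the per-sample losses over a batch."""
    losses = [cross_entropy(l, t) for l, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)
```

An undecided prediction such as logits `[0.0, 0.0]` costs ln 2 (about 0.693), while a confident correct prediction such as `[5.0, 0.0]` with target 0 costs under 0.01.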

Evaluation

Testing Data & Metrics

  • Dataset: Held-out test set from ChartQA, AI2D-RST, and COCO
  • Metrics:
    • Test Loss: 0.0371
    • Test Accuracy: 99.08%
    • Precision: 0.9995
    • Recall: 0.9820
    • F1 Score: 0.9907
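As a quick consistency check, the F1 score is the harmonic mean of precision and recall, so the reported value can be reproduced from the two figures above. The helper name `f1_score` is chosen for illustration only.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Precision and recall as reported in this card
reported_f1 = f1_score(0.9995, 0.9820)
print(round(reported_f1, 4))  # 0.9907
```

The result matches the reported F1 of 0.9907. Precision being higher than recall means the errors are mostly missed positives rather than false alarms.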

Environmental Impact

  • Hardware Used: NVIDIA A100 GPU
  • Compute Hours: Approx. X hours
  • Estimated Carbon Emission: Not reported; it can be estimated with the MLCO2 (Machine Learning Impact) calculator

Citation

If you use this model, please cite:

@misc{aya2025diaclass,
  author = {Aya Mohamed},
  title = {Diagram Classification Model},
  year = {2025},
  publisher = {Hugging Face},
  url = {https://huggingface.co/Ayamohamed/diaclass-model}
}