---
library_name: transformers
tags: []
---
# Usage
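Register the custom config and model classes with the `Auto` classes so that `AutoModel` can resolve the `mnist_classifier` model type: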
```python
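from transformers import AutoConfig, AutoModel

# MNISTConfig and MNISTClassifier are the repository's custom classes; define or
# import them before running this. MNISTConfig.model_type must equal "mnist_classifier".
AutoConfig.register("mnist_classifier", MNISTConfig)
AutoModel.register(MNISTConfig, MNISTClassifier)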
```
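The definitions of `MNISTConfig` and `MNISTClassifier` are not included in this README. The following is a minimal, hypothetical sketch of what such classes could look like, assuming a small fully connected network over flattened 28x28 inputs that returns raw logits. The `hidden_size` and `num_classes` fields and the layer layout are illustrative assumptions, not the repository's actual architecture, so this sketch is not guaranteed to match the published weights. It shows the pieces registration depends on: `model_type` on the config and `config_class` on the model.

```python
import torch
from torch import nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class MNISTConfig(PretrainedConfig):
    # Must match the string passed to AutoConfig.register
    model_type = "mnist_classifier"

    def __init__(self, hidden_size=128, num_classes=10, **kwargs):
        # Hypothetical hyperparameters; the published checkpoint may use others
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        super().__init__(**kwargs)


class MNISTClassifier(PreTrainedModel):
    # Links the model to its config class; AutoModel.register checks this
    config_class = MNISTConfig

    def __init__(self, config):
        super().__init__(config)
        # Hypothetical architecture: flatten (batch, 28, 28) -> (batch, 784), then a small MLP
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, config.hidden_size),
            nn.ReLU(),
            nn.Linear(config.hidden_size, config.num_classes),
        )
        # Standard transformers hook that initializes weights after construction
        self.post_init()

    def forward(self, x):
        # Raw logits of shape (batch, num_classes)
        return self.net(x)


AutoConfig.register("mnist_classifier", MNISTConfig)
AutoModel.register(MNISTConfig, MNISTClassifier)

# AutoModel now resolves MNISTConfig to MNISTClassifier (randomly initialized here)
model = AutoModel.from_config(MNISTConfig())
logits = model(torch.randn(1, 28, 28))
print(logits.shape)  # torch.Size([1, 10])
```

Because `forward` returns a plain logits tensor in this sketch, `output.argmax(-1)` in the inference snippet picks the most likely digit directly.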
Once the classes are registered, load the checkpoint from the Hub and run inference:
```python
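import torch
from transformers import AutoModel

# Requires the registration step above. AutoModel.from_pretrained also loads
# the model's config, so a separate AutoConfig.from_pretrained call is not needed.
model = AutoModel.from_pretrained("jerilseb/mnist-classifier")

# Dummy input of shape (batch, 28, 28); replace with a real 28x28 grayscale digit
# preprocessed the same way as the training data
input_tensor = torch.randn(1, 28, 28)

with torch.no_grad():
    output = model(input_tensor)

# .item() requires a single image; for a batch, use output.argmax(-1).tolist()
predicted_class = output.argmax(-1).item()
print(f"Predicted class: {predicted_class}")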
```
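The inference snippet reports only the single most likely class. A minimal sketch of turning the output into per-class probabilities and a top-3 ranking, assuming the model returns raw logits of shape `(batch, 10)`; a seeded random tensor stands in for the model output here, so the printed values are meaningless:

```python
import torch

# Stand-in for the model output. Assumption: raw, unnormalized scores
# for the 10 digit classes, one row per image.
torch.manual_seed(0)
logits = torch.randn(4, 10)

# Softmax over the class dimension turns each row into probabilities summing to 1
probs = torch.softmax(logits, dim=-1)

# The 3 most likely digits per image, sorted by descending probability
top_probs, top_classes = probs.topk(3, dim=-1)

for i in range(logits.shape[0]):
    ranking = ", ".join(
        f"{digit}: {p:.2%}"
        for digit, p in zip(top_classes[i].tolist(), top_probs[i].tolist())
    )
    print(f"image {i} -> {ranking}")
```

The first column of `top_classes` is the same as `logits.argmax(-1)`, since softmax preserves the ordering of the logits; the extra columns make ambiguous digits (for example a 4 that looks like a 9) easier to spot.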