|
--- |
|
library_name: transformers |
|
tags: [] |
|
--- |
|
|
|
# Usage |
|
|
|
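Register the custom `MNISTConfig` and `MNISTClassifier` classes with the Auto API so that `AutoModel.from_pretrained` can resolve the `mnist_classifier` model type: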
|
|
|
```python |
|
from transformers import AutoConfig, AutoModel

# MNISTConfig and MNISTClassifier are this checkpoint's custom configuration
# and model classes. They are not part of transformers, so import them from
# wherever they are defined before running this snippet, e.g.:
# from <your_module> import MNISTConfig, MNISTClassifier

# The model-type string must equal MNISTConfig.model_type; transformers raises
# a ValueError if the two disagree.
AutoConfig.register("mnist_classifier", MNISTConfig)

# Map the config class to the model class so AutoModel can build it. If
# MNISTClassifier sets `config_class`, it must be MNISTConfig.
AutoModel.register(MNISTConfig, MNISTClassifier)
|
``` |
|
Inference (run the registration step above first):
|
```python |
|
from transformers import AutoConfig, AutoModel
import torch

# Run the registration snippet above first so the Auto classes can resolve
# the custom "mnist_classifier" model type.
config = AutoConfig.from_pretrained("jerilseb/mnist-classifier")
model = AutoModel.from_pretrained("jerilseb/mnist-classifier", config=config)

# Random 28x28 input; dim 0 is the batch size, so (4, 28, 28) would be a
# batch of four images. NOTE: the exact input shape/normalisation the model
# expects is not documented here -- confirm against the model code.
input_tensor = torch.randn(1, 28, 28)

with torch.no_grad():
    output = model(input_tensor)

# argmax over the last dimension yields one class index per image. Use
# .tolist() rather than .item(): .item() fails for batches larger than 1.
predicted_classes = output.argmax(-1).tolist()
print(f"Predicted classes: {predicted_classes}")
|
``` |
|
|