Specify the head of the model in the model card (#2)
- Specify the head of the model in the model card (f83e01636e1335dd07330eaa06a99a5456333d8e)
- Update code snippet (c9ee1c31e23054aeb94b8bd83036b52e9f211e99)
README.md CHANGED
@@ -17,7 +17,7 @@ widget:
 
 # ONNX convert of ViT (base-sized model)
 
-
+Conversion of [ViT-base](https://huggingface.co/google/vit-base-patch16-224), which has a classification head to perform **image classification**.
 
 # Vision Transformer (base-sized model)
 
@@ -43,25 +43,22 @@ fine-tuned versions on a task that interests you.
 Here is how to use this model to classify an image of the COCO 2017 dataset into one of the 1,000 ImageNet classes:
 
 ```python
-from transformers import ViTFeatureExtractor, ViTForImageClassification
-from PIL import Image
-import requests
-
-url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
-image = Image.open(requests.get(url, stream=True).raw)
-
-feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
-model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
-
-inputs = feature_extractor(images=image, return_tensors="pt")
-outputs = model(**inputs)
-logits = outputs.logits
-# model predicts one of the 1000 ImageNet classes
-predicted_class_idx = logits.argmax(-1).item()
-print("Predicted class:", model.config.id2label[predicted_class_idx])
-```
+from transformers import AutoFeatureExtractor
+from optimum.onnxruntime import ORTModelForImageClassification
+from optimum.pipelines import pipeline
+
+feature_extractor = AutoFeatureExtractor.from_pretrained("optimum/vit-base-patch16-224")
+# Loading already converted and optimized ORT checkpoint for inference
+model = ORTModelForImageClassification.from_pretrained("optimum/vit-base-patch16-224")
 
-
+onnx_img_classif = pipeline(
+    "image-classification", model=model, feature_extractor=feature_extractor
+)
+url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+
+pred = onnx_img_classif(url)
+print("Top-5 predicted classes:", pred)
+```
 
 ## Training data
 
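The updated snippet leans on the pipeline's built-in post-processing, which turns the model's raw logits into a ranked list of `{"label": ..., "score": ...}` dicts. For reference, that top-k step can be sketched roughly as follows; this is an illustrative NumPy sketch with a toy 4-class `id2label` map, not part of the card itself:

```python
import numpy as np

def top_k_predictions(logits, id2label, k=5):
    """Softmax the logits and return the k highest-scoring labels,
    mimicking the dicts an image-classification pipeline yields."""
    exp = np.exp(logits - logits.max())  # subtract max for numerical stability
    probs = exp / exp.sum()
    top = np.argsort(probs)[::-1][:k]  # indices of the k largest probabilities
    return [{"label": id2label[i], "score": float(probs[i])} for i in top]

# Toy example with 4 classes instead of the 1,000 ImageNet classes
id2label = {0: "cat", 1: "dog", 2: "remote", 3: "couch"}
logits = np.array([2.0, 0.5, 3.0, -1.0])
print(top_k_predictions(logits, id2label, k=2))
```

With the real checkpoint, `k` defaults to 5 and `id2label` comes from `model.config.id2label`, which is why the card's snippet prints "Top-5 predicted classes".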