nielsr HF staff commited on
Commit
c629061
1 Parent(s): 13c2a91

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +8 -3
README.md CHANGED
@@ -33,16 +33,21 @@ fine-tuned versions on a task that interests you.
33
  Here is how to use this model to classify an image of the COCO 2017 dataset into one of the 1,000 ImageNet classes:
34
 
35
  ```python
36
- from transformers import BeitFeatureExtractor, BeitForImageClassification
37
  from PIL import Image
38
  import requests
 
39
  url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
40
  image = Image.open(requests.get(url, stream=True).raw)
41
- feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-large-patch16-224-pt22k-ft22k')
 
42
  model = BeitForImageClassification.from_pretrained('microsoft/beit-large-patch16-224-pt22k-ft22k')
43
- inputs = feature_extractor(images=image, return_tensors="pt")
 
 
44
  outputs = model(**inputs)
45
  logits = outputs.logits
 
46
  # model predicts one of the 21,841 ImageNet-22k classes
47
  predicted_class_idx = logits.argmax(-1).item()
48
  print("Predicted class:", model.config.id2label[predicted_class_idx])
33
  Here is how to use this model to classify an image of the COCO 2017 dataset into one of the 1,000 ImageNet classes:
34
 
35
  ```python
36
+ from transformers import BeitImageProcessor, BeitForImageClassification
37
  from PIL import Image
38
  import requests
39
+
40
  url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
41
  image = Image.open(requests.get(url, stream=True).raw)
42
+
43
+ processor = BeitImageProcessor.from_pretrained('microsoft/beit-large-patch16-224-pt22k-ft22k')
44
  model = BeitForImageClassification.from_pretrained('microsoft/beit-large-patch16-224-pt22k-ft22k')
45
+
46
+ inputs = processor(images=image, return_tensors="pt")
47
+
48
  outputs = model(**inputs)
49
  logits = outputs.logits
50
+
51
  # model predicts one of the 21,841 ImageNet-22k classes
52
  predicted_class_idx = logits.argmax(-1).item()
53
  print("Predicted class:", model.config.id2label[predicted_class_idx])