nikolaybi73 commited on
Commit
78369ea
1 Parent(s): 677c104
Files changed (1) hide show
  1. rawresnet.py +18 -0
rawresnet.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoImageProcessor, ResNetForImageClassification
2
+ import torch
3
+ from datasets import load_dataset
4
+
5
+ dataset = load_dataset("huggingface/cats-image")
6
+ image = dataset["test"]["image"][0]
7
+
8
+ processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")
9
+ model = ResNetForImageClassification.from_pretrained("microsoft/resnet-50")
10
+
11
+ inputs = processor(image, return_tensors="pt")
12
+
13
+ with torch.no_grad():
14
+ logits = model(**inputs).logits
15
+
16
+ # model predicts one of the 1000 ImageNet classes
17
+ predicted_label = logits.argmax(-1).item()
18
+ print(model.config.id2label[predicted_label])