oschamp commited on
Commit
062fa6d
1 Parent(s): 20feb73

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +23 -0
README.md CHANGED
@@ -83,3 +83,26 @@ The following hyperparameters were used during training:
83
  - Pytorch 1.13.1+cu117
84
  - Datasets 2.9.0
85
  - Tokenizers 0.13.2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  - Pytorch 1.13.1+cu117
84
  - Datasets 2.9.0
85
  - Tokenizers 0.13.2
86
+
87
+ ### Code to Run
88
+ def vit_classify(image):
89
+ vit = ViTForImageClassification.from_pretrained("oschamp/vit-artworkclassifier")
90
+ vit.eval()
91
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
92
+ vit.to(device)
93
+
94
+ model_name_or_path = 'google/vit-base-patch16-224-in21k'
95
+ feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
96
+
97
+ #LOAD IMAGE
98
+
99
+ encoding = feature_extractor(images=image, return_tensors="pt")
100
+ encoding.keys()
101
+
102
+ pixel_values = encoding['pixel_values'].to(device)
103
+
104
+ outputs = vit(pixel_values)
105
+ logits = outputs.logits
106
+
107
+ prediction = logits.argmax(-1)
108
+ return prediction.item() #vit.config.id2label[prediction.item()]