abidlabs HF Staff commited on
Commit
5fb3af5
·
1 Parent(s): 16c1617

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -1,16 +1,18 @@
1
  from transformers import ViTFeatureExtractor, ViTForImageClassification
2
  from PIL import Image
 
3
  import torch.nn.functional as F
4
  import time
5
 
 
6
 
7
- feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
8
- model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
9
 
10
  def predict(image):
11
  inputs = feature_extractor(images=image, return_tensors="pt")
12
- outputs = model(**inputs)
13
- logits = outputs.logits
14
  predicted_class_prob = F.softmax(logits, dim=-1).detach().numpy().max()
15
  predicted_class_idx = logits.argmax(-1).item()
16
  label = model.config.id2label[predicted_class_idx].split(",")[0]
 
1
  from transformers import ViTFeatureExtractor, ViTForImageClassification
2
  from PIL import Image
3
+ import torch
4
  import torch.nn.functional as F
5
  import time
6
 
7
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8
 
9
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224').to(device)
10
+ model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
11
 
12
  def predict(image):
13
  inputs = feature_extractor(images=image, return_tensors="pt")
14
+ outputs = model(**inputs).to(device)
15
+ logits = outputs.logits.to(device)
16
  predicted_class_prob = F.softmax(logits, dim=-1).detach().numpy().max()
17
  predicted_class_idx = logits.argmax(-1).item()
18
  label = model.config.id2label[predicted_class_idx].split(",")[0]