abidlabs HF Staff commited on
Commit
19d68bf
·
1 Parent(s): 164dd9b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -1
app.py CHANGED
@@ -1,3 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
- gr.Interface.load("huggingface/google/vit-base-patch16-224").launch()
 
1
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
2
+ from PIL import Image
3
+ import torch.nn.functional as F
4
+
5
+
6
+ feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
7
+ model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
8
+
9
+ def predict(image):
10
+ inputs = feature_extractor(images=image, return_tensors="pt")
11
+ outputs = model(**inputs)
12
+ logits = outputs.logits
13
+ predicted_class_prob = F.softmax(logits, dim=-1).detach().numpy().max()
14
+ predicted_class_idx = logits.argmax(-1).item()
15
+ label = model.config.id2label[predicted_class_idx].split(",")[0]
16
+ return {label: float(predicted_class_prob)}
17
+
18
  import gradio as gr
19
 
20
+ gr.Interface(predict, gr.Image(type="pil"), "label").launch()