ioanasong commited on
Commit
5dd71fd
1 Parent(s): e0724b2

added webcam file for huggingface

Browse files
Files changed (1) hide show
  1. app.py +55 -1
app.py CHANGED
@@ -1,3 +1,57 @@
 
 
 
1
  import gradio as gr
 
2
 
3
- gr.load("models/ioanasong/vit-MINC-2500").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import ViTFeatureExtractor, ViTForImageClassification
2
+ from PIL import Image
3
+ import torch
4
  import gradio as gr
5
+ from torch.nn import functional as F
6
 
7
+
8
+ # gr.load("models/ioanasong/vit-MINC-2500").launch()
9
+
10
+
11
+ # Load the pre-trained ViT model and feature extractor
12
+ model_name = "ioanasong/vit-MINC-2500"
13
+ model = ViTForImageClassification.from_pretrained(model_name)
14
+ model.eval()
15
+
16
+ feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
17
+
18
+
19
+ # Define the prediction function
20
+ # def predict(image):
21
+
22
+ # print(image)
23
+ # # Preprocess the image
24
+ # inputs = feature_extractor(images=image, return_tensors="pt")
25
+ # # Make prediction
26
+ # with torch.no_grad():
27
+ # outputs = model(**inputs)
28
+ # logits = outputs.logits
29
+ # # Get predicted label
30
+ # predicted_class_idx = logits.argmax(-1).item()
31
+ # predicted_label = model.config.id2label[predicted_class_idx]
32
+ # return predicted_label
33
+ def predict(image):
34
+ # Preprocess the image using the feature extractor
35
+ inputs = feature_extractor(images=image, return_tensors="pt")
36
+ # Make prediction using the model
37
+ with torch.no_grad():
38
+ outputs = model(**inputs)
39
+ logits = outputs.logits
40
+ # Compute softmax probabilities
41
+ probs = F.softmax(logits, dim=-1)[0]
42
+ # Create a dictionary of label and probability
43
+ prob_dict = {model.config.id2label[i]: prob.item() for i, prob in enumerate(probs)}
44
+ return prob_dict
45
+
46
+
47
+ # Create the Gradio interface
48
+ iface = gr.Interface(
49
+ fn=predict,
50
+ inputs=gr.Image(sources=['webcam'], streaming = True),
51
+ # outputs=gr.Label(num_top_classes=len(model.config.id2label)),
52
+ outputs=gr.Label(num_top_classes=5),
53
+ title="ViT Image Classification",
54
+ description="Capture an image from the camera and classify it using a pre-trained Vision Transformer (ViT) model.",
55
+ )
56
+ # Launch the app
57
+ iface.launch()