DGurgurov committed on
Commit
ad2f296
1 Parent(s): cebbd1f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -14
app.py CHANGED
@@ -1,9 +1,7 @@
1
  import gradio as gr
2
- import torch
3
  from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
4
  from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
5
  from PIL import Image
6
- import requests
7
  from datasets import load_dataset
8
 
9
  # Load your fine-tuned model and dataset
@@ -17,14 +15,6 @@ labels = list(set(dataset['train']['label']))
17
  label2id = {label: i for i, label in enumerate(labels)}
18
  id2label = {i: label for label, i in label2id.items()}
19
 
20
- # Define transformations for input images
21
- transform = Compose([
22
- Resize((224, 224)),
23
- CenterCrop(224),
24
- ToTensor(),
25
- Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
26
- ])
27
-
28
  # Function to classify image using CLIP model
29
  def classify_image(image):
30
  # Preprocess the image
@@ -34,13 +24,18 @@ def classify_image(image):
34
  # Run inference
35
  outputs = model(**inputs)
36
 
37
- # Get predicted label
38
- predicted_label_id = torch.argmax(outputs, dim=1).item()
39
- print(predicted_label_id)
 
 
 
 
40
  predicted_label = id2label[predicted_label_id]
41
 
42
  return predicted_label
43
 
 
44
  # Gradio interface
45
  iface = gr.Interface(
46
  fn=classify_image,
@@ -51,4 +46,4 @@ iface = gr.Interface(
51
  )
52
 
53
  # Launch the Gradio interface
54
- iface.launch()
 
1
  import gradio as gr
 
2
  from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
3
  from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
4
  from PIL import Image
 
5
  from datasets import load_dataset
6
 
7
  # Load your fine-tuned model and dataset
 
15
  label2id = {label: i for i, label in enumerate(labels)}
16
  id2label = {i: label for label, i in label2id.items()}
17
 
 
 
 
 
 
 
 
 
18
  # Function to classify image using CLIP model
19
  def classify_image(image):
20
  # Preprocess the image
 
24
  # Run inference
25
  outputs = model(**inputs)
26
 
27
+ # Extract logits and apply softmax
28
+ logits_per_image = outputs.logits_per_image # logits_per_image is a tensor with shape [1, num_labels]
29
+ probs = logits_per_image[0].softmax(dim=0) # Take the softmax across the labels
30
+
31
+ # Get predicted label id
32
+ predicted_label_id = probs.argmax().item()
33
+
34
  predicted_label = id2label[predicted_label_id]
35
 
36
  return predicted_label
37
 
38
+
39
  # Gradio interface
40
  iface = gr.Interface(
41
  fn=classify_image,
 
46
  )
47
 
48
  # Launch the Gradio interface
49
+ iface.launch()