LayBraid committed on
Commit 3432a10
1 Parent(s): c78b520

update app

Files changed (1)
  1. app.py +15 -9
app.py CHANGED
@@ -8,25 +8,31 @@ from transformers import CLIPProcessor, CLIPModel
  model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
  processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

- cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
+ cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=False, train=False)

- text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes])
- text_inputs_2 = ["a photo of a dog", "a photo of a cat"]
+ text_inputs = []

+ for c in cifar100.classes:
+     classes = "a photo of a " + c
+     print(classes)
+     text_inputs.append(classes)

- # TODO: debug this line to get a correct display
+ print(text_inputs)

-
- # TODO: finish displaying the result
+ test = ["a photo of a dog", "a photo of a cat"]


  def send_inputs(img):
-     inputs = processor(text=text_inputs_2, images=img, return_tensors="pt", padding=True)
+     inputs = processor(text=test, images=img, return_tensors="pt", padding=True)
+
      outputs = model(**inputs)
      logits_per_image = outputs.logits_per_image
      probs = logits_per_image.softmax(dim=1)
-     print(probs)
-     return probs
+
+     result = probs.argmax(dim=1)
+     index = result.item()
+     print(test[index])
+     return test[index]


  if __name__ == "__main__":
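
For reference, a minimal standalone sketch of the zero-shot flow this commit lands in send_inputs, assuming the same model, processor, and CIFAR-100 test split loaded above; the image variable and the final print are illustrative and not part of the Space's code, which wires send_inputs into the truncated __main__ block instead.

import os

from torchvision.datasets import CIFAR100
from transformers import CLIPProcessor, CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Same test split the commit loads; use download=True on a machine
# that does not already have the cached copy.
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

test = ["a photo of a dog", "a photo of a cat"]


def send_inputs(img):
    # Mirrors the function body after this commit: score both captions
    # against the image and return the better-matching one.
    inputs = processor(text=test, images=img, return_tensors="pt", padding=True)
    outputs = model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    return test[probs.argmax(dim=1).item()]


image, _ = cifar100[0]  # CIFAR100 yields (PIL image, class index) pairs
print(send_inputs(image))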