LayBraid commited on
Commit
998ea00
1 Parent(s): 21884ee

add requirements.txt

Browse files
Files changed (2) hide show
  1. app.py +5 -21
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,8 +1,7 @@
 
1
  import gradio as gr
2
  import os
3
  import torch
4
- from torchvision import transforms
5
- from PIL import Image
6
  from torchvision.datasets import CIFAR100
7
  from transformers import CLIPProcessor, CLIPModel
8
 
@@ -11,32 +10,17 @@ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
11
 
12
  cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
13
 
14
- IMG_SIZE = 32 if torch.cuda.is_available() else 32
15
- COMPOSED_TRANSFORMERS = transforms.Compose([
16
- transforms.Resize(IMG_SIZE),
17
- transforms.ToTensor(),
18
- ])
19
 
20
- NORMALIZE_TENSOR = transforms.Normalize(
21
- mean=[0.485, 0.456, 0.406],
22
- std=[0.229, 0.224, 0.225]
23
- )
24
 
 
25
 
26
- def np_array_to_tensor_image(img, width=IMG_SIZE, height=IMG_SIZE, device='cpu'):
27
- image = Image.fromarray(img).convert('RGB').resize((width, height))
28
- image = COMPOSED_TRANSFORMERS(image).unsqueeze(0)
29
- return image.to(device, torch.float)
30
 
31
-
32
- def normalize_tensor(tensor: torch.tensor) -> torch.tensor:
33
- return NORMALIZE_TENSOR(tensor)
34
 
35
 
36
  def send_inputs(img):
37
- ##img = np_array_to_tensor_image(img)
38
- ##img = normalize_tensor(img)
39
- inputs = processor(images=img, return_tensors="pt", padding=True)
40
  outputs = model(**inputs)
41
  logits_per_image = outputs.logits_per_image
42
  probs = logits_per_image.softmax(dim=1)
 
1
+ import clip
2
  import gradio as gr
3
  import os
4
  import torch
 
 
5
  from torchvision.datasets import CIFAR100
6
  from transformers import CLIPProcessor, CLIPModel
7
 
 
10
 
11
  cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
12
 
13
+ text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes])
 
 
 
 
14
 
 
 
 
 
15
 
16
+ # TODO debug cette ligne pour avoir un affichage correct
17
 
 
 
 
 
18
 
19
+ # TODO Finir l'affichage du résultat
 
 
20
 
21
 
22
  def send_inputs(img):
23
+ inputs = processor(text=cifar100.classes, images=img, return_tensors="pt", padding=True)
 
 
24
  outputs = model(**inputs)
25
  logits_per_image = outputs.logits_per_image
26
  probs = logits_per_image.softmax(dim=1)
requirements.txt CHANGED
@@ -2,4 +2,5 @@ torch~=1.11.0
2
  torchvision~=0.12.0
3
  gradio~=3.0.2
4
  Pillow~=9.0.1
5
- transformers~=4.19.4
 
 
2
  torchvision~=0.12.0
3
  gradio~=3.0.2
4
  Pillow~=9.0.1
5
+ transformers~=4.19.4
6
+ clip~=0.2.0