g8a9 committed
Commit 7369efb
1 Parent(s): dc1d715

Add static features

Files changed (4):
  1. .gitattributes +1 -0
  2. app.py +11 -97
  3. static/features/features.npy +3 -0
  4. utils.py +69 -0
.gitattributes CHANGED
@@ -14,3 +14,4 @@
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -16,6 +16,8 @@ from torchvision.transforms.functional import InterpolationMode
 from tqdm import tqdm
 from modeling_hybrid_clip import FlaxHybridCLIP
 
+import utils
+
 
 @st.cache
 def get_model():
@@ -39,92 +41,9 @@ def download_images():
     print("Done.")
 
 
-@st.cache(allow_output_mutation=True)
-def get_image_features(model, image_dir):
-    image_size = model.config.vision_config.image_size
-
-    val_preprocess = transforms.Compose(
-        [
-            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
-            CenterCrop(image_size),
-            ToTensor(),
-            Normalize(
-                (0.48145466, 0.4578275, 0.40821073),
-                (0.26862954, 0.26130258, 0.27577711),
-            ),
-        ]
-    )
-
-    dataset = CustomDataSet(image_dir, transform=val_preprocess)
-
-    loader = torch.utils.data.DataLoader(
-        dataset,
-        batch_size=16,
-        shuffle=False,
-        num_workers=4,
-        drop_last=False,
-    )
-
-    return precompute_image_features(loader), dataset
-
-
-class CustomDataSet(torch.utils.data.Dataset):
-    def __init__(self, main_dir, transform):
-        self.main_dir = main_dir
-        self.transform = transform
-        all_imgs = os.listdir(main_dir)
-        self.total_imgs = natsort.natsorted(all_imgs)
-
-    def __len__(self):
-        return len(self.total_imgs)
-
-    def get_image_name(self, idx):
-        return self.total_imgs[idx]
-
-    def __getitem__(self, idx):
-        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
-        image = PilImage.open(img_loc).convert("RGB")
-        tensor_image = self.transform(image)
-        return tensor_image
-
-
-def text_encoder(text, tokenizer):
-    inputs = tokenizer(
-        [text],
-        max_length=96,
-        truncation=True,
-        padding="max_length",
-        return_tensors="np",
-    )
-    embedding = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])[
-        0
-    ]
-    embedding /= jnp.linalg.norm(embedding)
-    return jnp.expand_dims(embedding, axis=0)
-
-
-@st.cache
-def precompute_image_features(model, loader):
-    image_features = []
-    for i, (images) in enumerate(tqdm(loader)):
-        images = images.permute(0, 2, 3, 1).numpy()
-        features = model.get_image_features(
-            images,
-        )
-        features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
-        image_features.extend(features)
-    return jnp.array(image_features)
-
-
-def find_image(text_query, dataset, tokenizer, image_features, n=1):
-    zeroshot_weights = text_encoder(text_query, tokenizer)
-    zeroshot_weights /= jnp.linalg.norm(zeroshot_weights)
-    distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))
-    file_paths = []
-    for i in range(1, n + 1):
-        idx = jnp.argsort(distances, axis=0)[-i, 0]
-        file_paths.append("photos/" + dataset.get_image_name(idx))
-    return file_paths
+@st.cache()
+def get_image_features():
+    return jnp.load("static/features/features.npy")
 
 
 """
@@ -142,6 +61,9 @@ if query:
     model = get_model()
     download_images()
 
+    image_features = get_image_features()
+
+    model = get_model()
     tokenizer = AutoTokenizer.from_pretrained(
         "dbmdz/bert-base-italian-xxl-uncased", cache_dir=None, use_fast=True
     )
@@ -160,18 +82,10 @@ if query:
         ]
     )
 
-    dataset = CustomDataSet("photos/", transform=val_preprocess)
+    dataset = utils.CustomDataSet("photos/", transform=val_preprocess)
 
-    loader = torch.utils.data.DataLoader(
-        dataset,
-        batch_size=16,
-        shuffle=False,
-        num_workers=2,
-        drop_last=False,
+    image_paths = utils.find_image(
+        query, model, dataset, tokenizer, image_features, n=2
     )
 
-    image_features = precompute_image_features(model, loader)
-
-    image_paths = find_image(query, dataset, tokenizer, image_features, n=2)
-
     st.image(image_paths)
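
The script that generated static/features/features.npy is not part of the commit, but the code this change deletes from app.py shows how the file can be reproduced offline. A minimal sketch, reusing the helpers now in utils.py; the checkpoint id is a placeholder for whatever get_model() actually loads:

# offline_precompute.py -- sketch reconstructed from the deleted app.py code;
# the checkpoint id below is a placeholder, not taken from this commit.
import numpy as np
import torch
from torchvision import transforms
from torchvision.transforms import CenterCrop, Normalize, Resize, ToTensor
from torchvision.transforms.functional import InterpolationMode

import utils
from modeling_hybrid_clip import FlaxHybridCLIP

model = FlaxHybridCLIP.from_pretrained("<checkpoint-used-by-get_model>")
image_size = model.config.vision_config.image_size

# Same CLIP preprocessing the removed get_image_features() used
val_preprocess = transforms.Compose(
    [
        Resize([image_size], interpolation=InterpolationMode.BICUBIC),
        CenterCrop(image_size),
        ToTensor(),
        Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

dataset = utils.CustomDataSet("photos/", transform=val_preprocess)
loader = torch.utils.data.DataLoader(
    dataset, batch_size=16, shuffle=False, num_workers=4, drop_last=False
)

# One L2-normalized row per image, in the dataset's natsorted order
features = utils.precompute_image_features(model, loader)
np.save("static/features/features.npy", np.asarray(features))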
static/features/features.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53a956386a27089b0bfe84bc311fbee885983815f5a6e9d9e58ec5c3a52015e9
+size 51191936
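
What is committed here is only the Git LFS pointer; a clone without the LFS blobs would hand jnp.load a three-line text file. A quick sanity check after fetching (a sketch; the array's exact shape is not recorded in the commit):

from jax import numpy as jnp

feats = jnp.load("static/features/features.npy")
print(feats.shape, feats.dtype)             # expected (num_images, embed_dim)
# precompute_image_features L2-normalizes each row, so norms should be ~1.0
print(jnp.linalg.norm(feats, axis=-1)[:3])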
utils.py ADDED
@@ -0,0 +1,69 @@
+import torch
+import os
+from transformers import AutoTokenizer
+from jax import numpy as jnp
+import json
+import requests
+import zipfile
+import io
+import natsort
+from PIL import Image as PilImage
+from tqdm import tqdm
+
+
+class CustomDataSet(torch.utils.data.Dataset):
+    def __init__(self, main_dir, transform):
+        self.main_dir = main_dir
+        self.transform = transform
+        all_imgs = os.listdir(main_dir)
+        self.total_imgs = natsort.natsorted(all_imgs)
+
+    def __len__(self):
+        return len(self.total_imgs)
+
+    def get_image_name(self, idx):
+        return self.total_imgs[idx]
+
+    def __getitem__(self, idx):
+        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
+        image = PilImage.open(img_loc).convert("RGB")
+        tensor_image = self.transform(image)
+        return tensor_image
+
+
+def text_encoder(text, model, tokenizer):
+    inputs = tokenizer(
+        [text],
+        max_length=96,
+        truncation=True,
+        padding="max_length",
+        return_tensors="np",
+    )
+    embedding = model.get_text_features(inputs["input_ids"], inputs["attention_mask"])[
+        0
+    ]
+    embedding /= jnp.linalg.norm(embedding)
+    return jnp.expand_dims(embedding, axis=0)
+
+
+def precompute_image_features(model, loader):
+    image_features = []
+    for i, (images) in enumerate(tqdm(loader)):
+        images = images.permute(0, 2, 3, 1).numpy()
+        features = model.get_image_features(
+            images,
+        )
+        features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
+        image_features.extend(features)
+    return jnp.array(image_features)
+
+
+def find_image(text_query, model, dataset, tokenizer, image_features, n=1):
+    zeroshot_weights = text_encoder(text_query, model, tokenizer)
+    zeroshot_weights /= jnp.linalg.norm(zeroshot_weights)
+    distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))
+    file_paths = []
+    for i in range(1, n + 1):
+        idx = jnp.argsort(distances, axis=0)[-i, 0]
+        file_paths.append("photos/" + dataset.get_image_name(idx))
+    return file_paths
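
End to end, retrieval now reduces to one matrix-vector product: both the cached image rows and the text embedding are L2-normalized, so the dot product in find_image is cosine similarity. A usage sketch, assuming model, tokenizer, and val_preprocess are built as in app.py (the query string is just an example):

from jax import numpy as jnp
import utils

image_features = jnp.load("static/features/features.npy")
dataset = utils.CustomDataSet("photos/", transform=val_preprocess)

# Top-2 matches for an Italian query, returned as file paths under photos/
paths = utils.find_image(
    "una spiaggia al tramonto", model, dataset, tokenizer, image_features, n=2
)
print(paths)

Note that the ranking is only correct if the rows of features.npy line up with CustomDataSet's natsorted directory listing, i.e. photos/ must contain exactly the images the features were precomputed from.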