g8a9 committed on
Commit
dc1d715
1 Parent(s): aa345fa
Files changed (1)
  1. app.py +29 -4
app.py CHANGED
@@ -59,7 +59,7 @@ def get_image_features(model, image_dir):
 
     loader = torch.utils.data.DataLoader(
         dataset,
-        batch_size=32,
+        batch_size=16,
         shuffle=False,
         num_workers=4,
         drop_last=False,
@@ -103,7 +103,8 @@ def text_encoder(text, tokenizer):
     return jnp.expand_dims(embedding, axis=0)
 
 
-def precompute_image_features(loader):
+@st.cache
+def precompute_image_features(model, loader):
     image_features = []
     for i, (images) in enumerate(tqdm(loader)):
         images = images.permute(0, 2, 3, 1).numpy()
@@ -145,8 +146,32 @@ if query:
         "dbmdz/bert-base-italian-xxl-uncased", cache_dir=None, use_fast=True
     )
 
-    image_features, dataset = get_image_features(model, "photos")
+    image_size = model.config.vision_config.image_size
+
+    val_preprocess = transforms.Compose(
+        [
+            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
+            CenterCrop(image_size),
+            ToTensor(),
+            Normalize(
+                (0.48145466, 0.4578275, 0.40821073),
+                (0.26862954, 0.26130258, 0.27577711),
+            ),
+        ]
+    )
+
+    dataset = CustomDataSet("photos/", transform=val_preprocess)
+
+    loader = torch.utils.data.DataLoader(
+        dataset,
+        batch_size=16,
+        shuffle=False,
+        num_workers=2,
+        drop_last=False,
+    )
+
+    image_features = precompute_image_features(model, loader)
 
-    image_paths = find_image(query, dataset, tokenizer, image_features, n=3)
+    image_paths = find_image(query, dataset, tokenizer, image_features, n=2)
 
     st.image(image_paths)
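
For context, the hunk above shows only the decorator, the new signature, and the first lines of the now-cached helper. Below is a minimal sketch of how the rest of its body plausibly continues; the `model.get_image_features` call and the L2 normalization are assumptions modeled on Flax CLIP-style demos and are not part of this commit.

# Sketch only: plausible completion of the cached helper from the diff above.
# The decorator, signature, and first three body lines come from the hunk;
# the feature extraction and normalization are assumptions (Flax CLIP-style API).
import jax.numpy as jnp
import streamlit as st
from tqdm import tqdm


@st.cache  # recompute the gallery features only when the inputs change
def precompute_image_features(model, loader):
    image_features = []
    for i, (images) in enumerate(tqdm(loader)):
        # DataLoader yields NCHW tensors; the Flax vision tower expects NHWC numpy arrays
        images = images.permute(0, 2, 3, 1).numpy()
        features = model.get_image_features(images)  # assumed CLIP-style call
        # L2-normalize so dot products with text embeddings act as cosine similarities
        features = features / jnp.linalg.norm(features, axis=-1, keepdims=True)
        image_features.append(features)
    return jnp.concatenate(image_features)

Wrapping the helper in @st.cache means the photos/ gallery is encoded once and reused across Streamlit reruns, instead of being re-encoded every time a new query triggers the script.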