g8a9 committed on
Commit
128a4c1
1 Parent(s): 7264c9b

Replicate IR on Unsplash with local download

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. app.py +139 -4
  3. requirements.txt +6 -3
.gitignore CHANGED
@@ -1 +1,2 @@
1
- __pycache__
 
1
+ __pycache__
2
+ photos
app.py CHANGED
@@ -1,4 +1,19 @@
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from modeling_hybrid_clip import FlaxHybridCLIP
3
 
4
 
@@ -7,12 +22,132 @@ def get_model():
7
  return FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  """
11
  # CLIP Italian Demo (Flax Community Week)
12
  """
13
 
14
- x = st.slider("Select a value")
15
- st.write(x, "squared is", x * x)
16
 
17
- model = get_model()
18
- st.write(str(model.config))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import os
3
+ import torch
4
+ from transformers import AutoTokenizer
5
+ from jax import numpy as jnp
6
+ import json
7
+ import requests
8
+ import zipfile
9
+ import io
10
+ import natsort
11
+ from PIL import Image as PilImage
12
+
13
+ from torchvision import datasets, transforms
14
+ from torchvision.transforms import CenterCrop, Normalize, Resize, ToTensor
15
+ from torchvision.transforms.functional import InterpolationMode
16
+ from tqdm import tqdm
17
  from modeling_hybrid_clip import FlaxHybridCLIP
18
 
19
 
22
  return FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")
23
 
24
 
25
@st.cache
def download_images():
    """Download and extract the 25k-image Unsplash demo set into ``photos/``.

    Runs only when ``photos/`` is missing or empty; cached by Streamlit so it
    executes at most once per session. Side effects only (network + disk);
    returns None.
    """
    img_folder = "photos/"
    if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
        os.makedirs(img_folder, exist_ok=True)

        photo_filename = "unsplash-25k-photos.zip"
        if not os.path.exists(photo_filename):  # Download dataset if does not exist
            print(f"Downloading {photo_filename}...")
            r = requests.get("http://sbert.net/datasets/" + photo_filename, stream=True)
            # Fail loudly on a bad response instead of trying to unzip an error page.
            r.raise_for_status()
            with open(photo_filename, "wb") as f:
                f.write(r.content)
        # BUG FIX: extraction used to live inside the download branch, so a
        # zip left over from an interrupted run was never extracted and the
        # photos folder stayed empty forever. Extract unconditionally here.
        print("Extracting the dataset...")
        with zipfile.ZipFile(photo_filename) as z:
            z.extractall(path=img_folder)
        print("Done.")
40
+
41
+
42
@st.cache
def get_image_features(model, image_dir):
    """Precompute unit-norm CLIP embeddings for every image under *image_dir*.

    Returns a tuple ``(features, dataset)``: the stacked feature matrix and
    the underlying dataset (needed later to map row indices back to
    filenames).
    """
    side = model.config.vision_config.image_size

    # Standard CLIP preprocessing: bicubic resize of the short edge,
    # center crop, then normalization with the published CLIP statistics.
    preprocess = transforms.Compose(
        [
            Resize([side], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(side),
            ToTensor(),
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )

    dataset = CustomDataSet(image_dir, transform=preprocess)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=256,
        shuffle=False,
        num_workers=2,
        persistent_workers=True,
        drop_last=False,
    )

    features = precompute_image_features(loader)
    return features, dataset
70
+
71
+
72
class CustomDataSet(torch.utils.data.Dataset):
    """Flat image-folder dataset whose items are preprocessed image tensors.

    Filenames are naturally sorted (natsort) so row indices of precomputed
    features map deterministically back to filenames via get_image_name().
    """

    def __init__(self, main_dir, transform):
        self.main_dir = main_dir          # directory containing the images
        self.transform = transform        # torchvision preprocessing pipeline
        all_imgs = os.listdir(main_dir)
        self.total_imgs = natsort.natsorted(all_imgs)

    def __len__(self):
        return len(self.total_imgs)

    def get_image_name(self, idx):
        """Return the bare filename for dataset row *idx*."""
        return self.total_imgs[idx]

    def __getitem__(self, idx):
        img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
        # BUG FIX: close the file handle; the original leaked one open
        # descriptor per image (25k images exhausts the fd limit on some
        # systems). convert("RGB") loads pixel data, so closing is safe.
        with PilImage.open(img_loc) as image:
            tensor_image = self.transform(image.convert("RGB"))
        return tensor_image
90
+
91
+
92
def text_encoder(text, tokenizer, clip_model=None):
    """Embed *text* into the joint CLIP space as a unit-norm (1, dim) array.

    Args:
        text: query string to embed.
        tokenizer: HF tokenizer; must accept a list of strings and the
            max_length/truncation/padding/return_tensors kwargs used below.
        clip_model: optional model exposing ``get_text_features``; when None,
            falls back to the script-level ``model`` global (backward
            compatible with the original implicit-global behavior).

    Returns:
        jnp array of shape (1, dim) with L2 norm 1.
    """
    if clip_model is None:
        clip_model = model  # legacy fallback: global assigned by the UI flow
    inputs = tokenizer(
        [text],
        max_length=96,
        truncation=True,
        padding="max_length",
        return_tensors="np",
    )
    embedding = clip_model.get_text_features(
        inputs["input_ids"], inputs["attention_mask"]
    )[0]
    # Normalize so downstream dot products are cosine similarities.
    embedding /= jnp.linalg.norm(embedding)
    return jnp.expand_dims(embedding, axis=0)
105
+
106
+
107
def precompute_image_features(loader, clip_model=None):
    """Embed every batch from *loader* and return unit-norm features (N, dim).

    Args:
        loader: DataLoader yielding NCHW float image batches.
        clip_model: optional model exposing ``get_image_features``; when None,
            falls back to the script-level ``model`` global (backward
            compatible with the original implicit-global behavior).
    """
    if clip_model is None:
        clip_model = model  # legacy fallback: global assigned by the UI flow
    image_features = []
    for images in tqdm(loader):
        # Flax vision tower expects NHWC; the loader yields NCHW torch tensors.
        images = images.permute(0, 2, 3, 1).numpy()
        features = clip_model.get_image_features(images)
        # Row-wise L2 normalization -> dot products become cosine similarity.
        features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
        image_features.extend(features)
    return jnp.array(image_features)
117
+
118
+
119
def find_image(text_query, dataset, tokenizer, image_features, n=1, img_folder="photos/"):
    """Return paths of the *n* images most similar to *text_query*.

    Args:
        text_query: free-text query.
        dataset: CustomDataSet providing get_image_name(idx).
        tokenizer: tokenizer forwarded to text_encoder.
        image_features: (N, dim) unit-norm image embedding matrix.
        n: number of results, best match first.
        img_folder: prefix joined to each filename (default keeps the
            original hard-coded "photos/" behavior).
    """
    zeroshot_weights = text_encoder(text_query, tokenizer)
    zeroshot_weights /= jnp.linalg.norm(zeroshot_weights)
    distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))
    # PERF FIX: sort once; the original re-ran the full argsort on every
    # loop iteration just to read a different element of the same ordering.
    order = jnp.argsort(distances, axis=0)
    file_paths = []
    for i in range(1, n + 1):
        idx = order[-i, 0]
        file_paths.append(img_folder + dataset.get_image_name(idx))
    return file_paths
128
+
129
+
130
# Streamlit "magic": a bare top-level string literal is rendered as markdown.
"""
# CLIP Italian Demo (Flax Community Week)
"""

# Silence the HF tokenizers fork warning once DataLoader workers spawn.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


query = st.text_input("Insert a query text")
if query:

    with st.spinner("Computing in progress..."):
        # Model load and image download are @st.cache'd, so reruns are cheap.
        model = get_model()
        download_images()

        tokenizer = AutoTokenizer.from_pretrained(
            "dbmdz/bert-base-italian-xxl-uncased", cache_dir=None, use_fast=True
        )

        # Precomputed unit-norm embeddings for every photo, plus the dataset
        # used to map feature-row indices back to filenames.
        image_features, dataset = get_image_features(model, "photos")

        # Top-3 matches for the query, best first.
        image_paths = find_image(query, dataset, tokenizer, image_features, n=3)

        st.image(image_paths)
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
- jax==0.2.17
2
- flax==0.3.4
3
- transformers==4.8.2
 
 
 
1
+ jax
2
+ flax
3
+ transformers
4
+ torch
5
+ torchvision
6
+ natsort