g8a9 committed on
Commit c5ad46a
1 Parent(s): 0789e97

[text2image] Add IR for the CC validation set

static/CC_val_urls.txt ADDED
The diff for this file is too large to render. See raw diff
 
static/features/{cc_features.npy → CC_val_embeddings.npy} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:63f185e851ff9cd0a19c5b1877087d860ca53ec5fc9e6a7d608249b9aacb77df
-size 2050773120
+oid sha256:775803a42011b09e8f5d19fcbdd67123cc3447154e1f8e5990cae1bce4581662
+size 27369600
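
The embeddings file stays a Git LFS pointer, so only its oid and size change here. A minimal sanity-check sketch, assuming (this is not stated in the commit) that CC_val_embeddings.npy holds one L2-normalized float32 row per line of static/CC_val_urls.txt, in the same order:

import jax.numpy as jnp

# Load the renamed CC validation embeddings and the matching URL list
# (paths are relative to a local checkout of the Space).
image_features = jnp.load("static/features/CC_val_embeddings.npy")
with open("static/CC_val_urls.txt") as fp:
    urls = [l.strip() for l in fp.readlines()]

# Assumed invariants: one embedding per URL, rows already unit-normalized
# (precompute_image_features in utils.py normalizes before saving).
assert image_features.shape[0] == len(urls)
assert jnp.allclose(jnp.linalg.norm(image_features, axis=-1), 1.0, atol=1e-3)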
text2image.py CHANGED
@@ -22,9 +22,15 @@ def get_model():
     return FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")


-@st.cache(hash_funcs={transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: lambda _: None})
+@st.cache(
+    hash_funcs={
+        transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: lambda _: None
+    }
+)
 def get_tokenizer():
-    return AutoTokenizer.from_pretrained("dbmdz/bert-base-italian-xxl-uncased", cache_dir="./", use_fast=True)
+    return AutoTokenizer.from_pretrained(
+        "dbmdz/bert-base-italian-xxl-uncased", cache_dir="./", use_fast=True
+    )


 @st.cache(suppress_st_warning=True)
@@ -37,10 +43,14 @@ def download_images():
     photo_filename = "unsplash-25k-photos.zip"
     if not os.path.exists(photo_filename): # Download dataset if does not exist
         print(f"Downloading {photo_filename}...")
-        response = requests.get(f"http://sbert.net/datasets/{photo_filename}", stream=True)
-        total_size_in_bytes = int(response.headers.get('content-length', 0))
+        response = requests.get(
+            f"http://sbert.net/datasets/{photo_filename}", stream=True
+        )
+        total_size_in_bytes = int(response.headers.get("content-length", 0))
         block_size = 1024 # 1 Kb
-        progress_bar = stqdm(total=total_size_in_bytes) # , unit='iB', unit_scale=True
+        progress_bar = stqdm(
+            total=total_size_in_bytes
+        ) # , unit='iB', unit_scale=True
         content = io.BytesIO()
         for data in response.iter_content(block_size):
             progress_bar.update(len(data))
@@ -54,8 +64,21 @@ def download_images():


 @st.cache()
-def get_image_features():
-    return jnp.load("static/features/features.npy")
+def get_image_features(dataset_name):
+    if dataset_name == "Unsplash":
+        return jnp.load("static/features/features.npy")
+    else:
+        return jnp.load("static/features/CC_val_embeddings.npy")
+
+
+@st.cache()
+def load_urls(dataset_name):
+    if dataset_name == "CC":
+        with open("static/CC_val_urls.txt") as fp:
+            urls = [l.strip() for l in fp.readlines()]
+        return urls
+    else:
+        ValueError(f"{dataset_name} not supported here")


 def app():
@@ -73,7 +96,7 @@ def app():
     """
     )

-    if 'suggestion' not in st.session_state:
+    if "suggestion" not in st.session_state:
         st.session_state.suggestion = ""

     def update_query(value=""):
@@ -81,44 +104,61 @@ def app():

     col1, col2, col3, col4 = st.beta_columns(4)
     with col1:
-        st.button('Un gatto', on_click=update_query, kwargs=dict(value='Un gatto'))
+        st.button("Un gatto", on_click=update_query, kwargs=dict(value="Un gatto"))
     with col2:
-        st.button('Due gatti', on_click=update_query, kwargs=dict(value='Due gatti'))
+        st.button("Due gatti", on_click=update_query, kwargs=dict(value="Due gatti"))
     with col3:
-        st.button('Un fiore giallo', on_click=update_query, kwargs=dict(value='Un fiore giallo'))
+        st.button(
+            "Un fiore giallo",
+            on_click=update_query,
+            kwargs=dict(value="Un fiore giallo"),
+        )
     with col4:
-        st.button('Un fiore blu', on_click=update_query, kwargs=dict(value='Un fiore blu'))
+        st.button(
+            "Un fiore blu", on_click=update_query, kwargs=dict(value="Un fiore blu")
+        )

-    query = st.text_input('Insert an italian query text here...', st.session_state.suggestion)
+    col1, col2 = st.beta_columns([3, 1])
+    with col1:
+        query = st.text_input(
+            "Insert an italian query text here...", st.session_state.suggestion
+        )
+    with col2:
+        dataset_name = st.selectbox("IR dataset", ["Unsplash", "CC"])

     if query:
-        with st.spinner("Computing in progress..."):
+        with st.spinner("Computing..."):
+
             model = get_model()
-            download_images()

-            image_features = get_image_features()
+            if dataset_name == "Unsplash":
+                download_images()

+            image_features = get_image_features(dataset_name)
             model = get_model()
             tokenizer = get_tokenizer()

-            image_size = model.config.vision_config.image_size
-
-            val_preprocess = Compose(
-                [
-                    Resize([image_size], interpolation=InterpolationMode.BICUBIC),
-                    CenterCrop(image_size),
-                    ToTensor(),
-                    Normalize(
-                        (0.48145466, 0.4578275, 0.40821073),
-                        (0.26862954, 0.26130258, 0.27577711),
-                    ),
-                ]
-            )
-
-            dataset = utils.CustomDataSet("photos/", transform=val_preprocess)
+            if dataset_name == "Unsplash":
+                image_size = model.config.vision_config.image_size
+                val_preprocess = Compose(
+                    [
+                        Resize([image_size], interpolation=InterpolationMode.BICUBIC),
+                        CenterCrop(image_size),
+                        ToTensor(),
+                        Normalize(
+                            (0.48145466, 0.4578275, 0.40821073),
+                            (0.26862954, 0.26130258, 0.27577711),
+                        ),
+                    ]
+                )
+                dataset = utils.CustomDataSet("photos/", transform=val_preprocess)
+            elif dataset_name == "CC":
+                dataset = load_urls(dataset_name)
+            else:
+                raise ValueError()

             image_paths = utils.find_image(
-                query, model, dataset, tokenizer, image_features, n=2
+                query, model, dataset, tokenizer, image_features, 2, dataset_name
             )

             st.image(image_paths)
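
The app now branches on dataset_name: the "Unsplash" path downloads the photos and builds a local CustomDataSet, while the "CC" path only loads the precomputed embeddings plus the URL list, and find_image returns URLs that st.image can display directly (note that load_urls constructs, but does not raise, a ValueError for other dataset names). A rough sketch of the CC retrieval path outside Streamlit, using only the helpers added above; the query string is an arbitrary example, and st.cache may warn when these functions are called outside a Streamlit session:

import utils
from text2image import get_image_features, get_model, get_tokenizer, load_urls

model = get_model()
tokenizer = get_tokenizer()

# CC branch: precomputed embeddings plus one URL per row, no image download.
image_features = get_image_features("CC")
dataset = load_urls("CC")  # list of image URLs

# Top-2 CC images for an Italian query, returned as URLs.
image_urls = utils.find_image(
    "Un gatto", model, dataset, tokenizer, image_features, 2, "CC"
)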
utils.py CHANGED
@@ -45,20 +45,24 @@ def precompute_image_features(model, loader):
     image_features = []
     for i, (images) in enumerate(tqdm(loader)):
         images = images.permute(0, 2, 3, 1).numpy()
-        features = model.get_image_features(
-            images,
-        )
+        features = model.get_image_features(images,)
         features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
         image_features.extend(features)
     return jnp.array(image_features)


-def find_image(text_query, model, dataset, tokenizer, image_features, n=1):
+def find_image(text_query, model, dataset, tokenizer, image_features, n, dataset_name):
     zeroshot_weights = text_encoder(text_query, model, tokenizer)
     zeroshot_weights /= jnp.linalg.norm(zeroshot_weights)
     distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))
     file_paths = []
     for i in range(1, n + 1):
         idx = jnp.argsort(distances, axis=0)[-i, 0]
-        file_paths.append("photos/" + dataset.get_image_name(idx))
+
+        if dataset_name == "Unsplash":
+            file_paths.append("photos/" + dataset.get_image_name(idx))
+        elif dataset_name == "CC":
+            file_paths.append(dataset[idx])
+        else:
+            raise ValueError(f"{dataset_name} not supported here")
     return file_paths
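
For completeness, precompute_image_features (unchanged above except for formatting) is what produces files such as static/features/features.npy. A sketch of regenerating the Unsplash features with it, assuming a standard torch DataLoader over the CustomDataSet and the same preprocessing as the Unsplash branch in text2image.py; the batch size is arbitrary:

import numpy as np
from torch.utils.data import DataLoader
from torchvision.transforms import (
    CenterCrop, Compose, InterpolationMode, Normalize, Resize, ToTensor,
)

import utils
from text2image import get_model

model = get_model()
image_size = model.config.vision_config.image_size

# Same preprocessing pipeline used for Unsplash retrieval in text2image.py.
val_preprocess = Compose(
    [
        Resize([image_size], interpolation=InterpolationMode.BICUBIC),
        CenterCrop(image_size),
        ToTensor(),
        Normalize(
            (0.48145466, 0.4578275, 0.40821073),
            (0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

dataset = utils.CustomDataSet("photos/", transform=val_preprocess)
loader = DataLoader(dataset, batch_size=256, shuffle=False)

# One L2-normalized embedding per image, in dataset order.
features = utils.precompute_image_features(model, loader)
np.save("static/features/features.npy", np.asarray(features))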