vinid commited on
Commit
a5b18fc
2 Parent(s): 3c6a443 a01e989

Merge branch 'main' of https://huggingface.co/spaces/clip-italian/clip-italian-demo into main

Browse files
image2text.py CHANGED
@@ -1,4 +1,72 @@
1
  import streamlit as st
 
 
 
 
 
 
2
 
3
  def app():
4
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from text2image import get_model, get_tokenizer, get_image_transform
3
+ from utils import text_encoder, image_encoder
4
+ from PIL import Image
5
+ from jax import numpy as jnp
6
+ import pandas as pd
7
+
8
 
9
  def app():
10
+ st.title("From Image to Text")
11
+ st.markdown(
12
+ """
13
+
14
+ ### 👋 Ciao!
15
+
16
+ Here you can find the captions that are most related to a given image.
17
+
18
+ 🤌 Italian mode on! 🤌
19
+
20
+ """
21
+ )
22
+
23
+ filename = st.file_uploader(
24
+ "Choose an image from your computer", type=["jpg", "jpeg", "png"]
25
+ )
26
+
27
+ MAX_CAP = 4
28
+
29
+ col1, col2 = st.beta_columns([3, 1])
30
+
31
+ with col2:
32
+ captions_count = st.selectbox(
33
+ "Number of captions", options=range(1, MAX_CAP + 1)
34
+ )
35
+ compute = st.button("Compute")
36
+
37
+ with col1:
38
+ captions = list()
39
+ for idx in range(min(MAX_CAP, captions_count)):
40
+ captions.append(st.text_input(f"Insert Caption {idx+1}"))
41
+
42
+ if compute:
43
+ captions = [c for c in captions if c != ""]
44
+
45
+ if not captions or not filename:
46
+ st.error("Please choose one image and at least one caption")
47
+ else:
48
+ with st.spinner("Computing..."):
49
+ model = get_model()
50
+ tokenizer = get_tokenizer()
51
+
52
+ text_embeds = list()
53
+ for i, c in enumerate(captions):
54
+ text_embeds.extend(text_encoder(c, model, tokenizer))
55
+
56
+ text_embeds = jnp.array(text_embeds)
57
+
58
+ image = Image.open(filename).convert("RGB")
59
+ transform = get_image_transform(model.config.vision_config.image_size)
60
+ image_embed = image_encoder(transform(image), model)
61
+
62
+ # we could have a softmax here
63
+ cos_similarities = jnp.matmul(image_embed, text_embeds.T)
64
+
65
+ chart_data = pd.Series(cos_similarities[0], index=captions)
66
+
67
+ col1, col2 = st.beta_columns(2)
68
+ with col1:
69
+ st.bar_chart(chart_data)
70
+
71
+ with col2:
72
+ st.image(image)
introduction.md CHANGED
@@ -54,6 +54,8 @@ a dataset with 700K translated captions.
54
 
55
  ## Better Augmentations
56
 
 
 
57
  ## Better Training
58
 
59
  After different trials, we realized that the usual way of training this model was
@@ -62,17 +64,15 @@ training pipeline: the optimizer and the training with frozen components.
62
 
63
  ### Optimizer
64
 
65
- The standard AdamW didn't seem enough to train the model and thus we opted for a different optimization strategy. We eventually used AdaBelief with AGC and Cosine Annealing.
66
  Our implementation is available online [here](https://github.com/clip-italian/clip-italian/blob/master/hybrid_clip/run_hybrid_clip.py#L667).
67
 
68
  ### Backbone Freezing
69
 
70
  The ViT used by OpenAI was already trained on 400million images and it is the element in our architecture that probably required less training.
71
- The same is true for the BERT model we use. Thus, we decided to do a first training with the backbone of our architecture completely frozen, to allow
72
- the deeper layer to adapt to the new setting. Eventually, we run a new training, by fine-tuning al the components. This technique allowed us to
73
- reach a much better validation loss.
74
 
75
- <img src="https://huggingface.co/spaces/clip-italian/clip-italian-demo/raw/main/static/img/clip-italian.png" alt="drawing" width="600"/>
76
 
77
  # Scientific Validity
78
 
 
54
 
55
  ## Better Augmentations
56
 
57
+ We knew that without a good augmentation strategy we could never get competitive results to a model trained on 400 million images. Therefor we implemented heavy augmentations to make the training more data efficient. We made sure to keep hue augmentations limited however to still give the model the ability to learn color definitions. While we would have liked to have augmentations for the captions as well after some experimentation we settled with random sampling from the five captions available in MSCOCO and leaving the rest of the captions unmodified.
58
+
59
  ## Better Training
60
 
61
  After different trials, we realized that the usual way of training this model was
 
64
 
65
  ### Optimizer
66
 
67
+ While the initial code used AdamW as an optimizer we soon noticed that it introduced some bad properties into the training. The model strated to overfit relatively quickly and the weight decay made this effect worse. We eventually decided to an optimization strategy that had worked well for us in similar cases and used AdaBelief with Adaptive Gradient Clipping (AGC) and a Cosine Annealing Schedule. Together with slightly tuning the learning rate this helped us to reduce the validation loss by 25%.
68
  Our implementation is available online [here](https://github.com/clip-italian/clip-italian/blob/master/hybrid_clip/run_hybrid_clip.py#L667).
69
 
70
  ### Backbone Freezing
71
 
72
  The ViT used by OpenAI was already trained on 400million images and it is the element in our architecture that probably required less training.
73
+ The same is true for the BERT model we use. To allow the randomly initialized Re-projection Layers to warm up without messing with the tuned weights of the backbones we decided to do a first training with the backbones of our architecture completely frozen. Only after these layers converged did we unfreeze the rest of the model to fine-tune all the components. This technique allowed us to reach a much better validation loss.
 
 
74
 
75
+ <img src="https://huggingface.co/spaces/clip-italian/clip-italian-demo/raw/main/static/img/clip-italian.png" alt="drawing" width="50%"/>
76
 
77
  # Scientific Validity
78
 
requirements.txt CHANGED
@@ -4,4 +4,5 @@ transformers
4
  torch
5
  torchvision
6
  natsort
7
- stqdm
 
 
4
  torch
5
  torchvision
6
  natsort
7
+ stqdm
8
+ pandas
static/CC_val_urls.txt ADDED
The diff for this file is too large to render. See raw diff
 
static/features/{cc_features.npy → CC_val_embeddings.npy} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:63f185e851ff9cd0a19c5b1877087d860ca53ec5fc9e6a7d608249b9aacb77df
3
- size 2050773120
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:775803a42011b09e8f5d19fcbdd67123cc3447154e1f8e5990cae1bce4581662
3
+ size 27369600
text2image.py CHANGED
@@ -22,9 +22,15 @@ def get_model():
22
  return FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")
23
 
24
 
25
- @st.cache(hash_funcs={transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: lambda _: None})
 
 
 
 
26
  def get_tokenizer():
27
- return AutoTokenizer.from_pretrained("dbmdz/bert-base-italian-xxl-uncased", cache_dir="./", use_fast=True)
 
 
28
 
29
 
30
  @st.cache(suppress_st_warning=True)
@@ -37,10 +43,14 @@ def download_images():
37
  photo_filename = "unsplash-25k-photos.zip"
38
  if not os.path.exists(photo_filename): # Download dataset if does not exist
39
  print(f"Downloading {photo_filename}...")
40
- response = requests.get(f"http://sbert.net/datasets/{photo_filename}", stream=True)
41
- total_size_in_bytes = int(response.headers.get('content-length', 0))
 
 
42
  block_size = 1024 # 1 Kb
43
- progress_bar = stqdm(total=total_size_in_bytes) # , unit='iB', unit_scale=True
 
 
44
  content = io.BytesIO()
45
  for data in response.iter_content(block_size):
46
  progress_bar.update(len(data))
@@ -54,53 +64,106 @@ def download_images():
54
 
55
 
56
  @st.cache()
57
- def get_image_features():
58
- return jnp.load("static/features/features.npy")
59
-
60
- def app():
 
61
 
62
- """
63
 
64
- # 👋 Ciao!
65
-
66
- # CLIP Italian Demo
67
- ## HF-Flax Community Week
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- In this demo you can search for images in the Unsplash 25k Photos dataset.
70
 
71
- 🤌 Italian mode on! 🤌
72
 
73
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- query = st.text_input("Insert an italian query text here...")
76
  if query:
77
- with st.spinner("Computing in progress..."):
 
78
  model = get_model()
79
- download_images()
80
 
81
- image_features = get_image_features()
 
82
 
 
83
  model = get_model()
84
  tokenizer = get_tokenizer()
85
 
86
- image_size = model.config.vision_config.image_size
87
-
88
- val_preprocess = Compose(
89
- [
90
- Resize([image_size], interpolation=InterpolationMode.BICUBIC),
91
- CenterCrop(image_size),
92
- ToTensor(),
93
- Normalize(
94
- (0.48145466, 0.4578275, 0.40821073),
95
- (0.26862954, 0.26130258, 0.27577711),
96
- ),
97
- ]
98
- )
99
-
100
- dataset = utils.CustomDataSet("photos/", transform=val_preprocess)
101
 
102
  image_paths = utils.find_image(
103
- query, model, dataset, tokenizer, image_features, n=2
104
  )
105
 
106
  st.image(image_paths)
 
22
  return FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")
23
 
24
 
25
+ @st.cache(
26
+ hash_funcs={
27
+ transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: lambda _: None
28
+ }
29
+ )
30
  def get_tokenizer():
31
+ return AutoTokenizer.from_pretrained(
32
+ "dbmdz/bert-base-italian-xxl-uncased", cache_dir="./", use_fast=True
33
+ )
34
 
35
 
36
  @st.cache(suppress_st_warning=True)
 
43
  photo_filename = "unsplash-25k-photos.zip"
44
  if not os.path.exists(photo_filename): # Download dataset if does not exist
45
  print(f"Downloading {photo_filename}...")
46
+ response = requests.get(
47
+ f"http://sbert.net/datasets/{photo_filename}", stream=True
48
+ )
49
+ total_size_in_bytes = int(response.headers.get("content-length", 0))
50
  block_size = 1024 # 1 Kb
51
+ progress_bar = stqdm(
52
+ total=total_size_in_bytes
53
+ ) # , unit='iB', unit_scale=True
54
  content = io.BytesIO()
55
  for data in response.iter_content(block_size):
56
  progress_bar.update(len(data))
 
64
 
65
 
66
  @st.cache()
67
+ def get_image_features(dataset_name):
68
+ if dataset_name == "Unsplash":
69
+ return jnp.load("static/features/features.npy")
70
+ else:
71
+ return jnp.load("static/features/CC_val_embeddings.npy")
72
 
 
73
 
74
+ @st.cache()
75
+ def load_urls(dataset_name):
76
+ if dataset_name == "CC":
77
+ with open("static/CC_val_urls.txt") as fp:
78
+ urls = [l.strip() for l in fp.readlines()]
79
+ return urls
80
+ else:
81
+ ValueError(f"{dataset_name} not supported here")
82
+
83
+
84
+ def get_image_transform(image_size):
85
+ return Compose(
86
+ [
87
+ Resize([image_size], interpolation=InterpolationMode.BICUBIC),
88
+ CenterCrop(image_size),
89
+ ToTensor(),
90
+ Normalize(
91
+ (0.48145466, 0.4578275, 0.40821073),
92
+ (0.26862954, 0.26130258, 0.27577711),
93
+ ),
94
+ ]
95
+ )
96
 
 
97
 
98
+ def app():
99
 
100
+ st.title("From Text to Image")
101
+ st.markdown(
102
+ """
103
+
104
+ ### 👋 Ciao!
105
+
106
+ Here you can search for images in the Unsplash 25k Photos dataset.
107
+
108
+ 🤌 Italian mode on! 🤌
109
+
110
+ """
111
+ )
112
+
113
+ if "suggestion" not in st.session_state:
114
+ st.session_state.suggestion = ""
115
+
116
+ def update_query(value=""):
117
+ st.session_state.suggestion = value
118
+
119
+ col1, col2, col3, col4 = st.beta_columns(4)
120
+ with col1:
121
+ st.button("Un gatto", on_click=update_query, kwargs=dict(value="Un gatto"))
122
+ with col2:
123
+ st.button("Due gatti", on_click=update_query, kwargs=dict(value="Due gatti"))
124
+ with col3:
125
+ st.button(
126
+ "Un fiore giallo",
127
+ on_click=update_query,
128
+ kwargs=dict(value="Un fiore giallo"),
129
+ )
130
+ with col4:
131
+ st.button(
132
+ "Un fiore blu", on_click=update_query, kwargs=dict(value="Un fiore blu")
133
+ )
134
+
135
+ col1, col2 = st.beta_columns([3, 1])
136
+ with col1:
137
+ query = st.text_input(
138
+ "Insert an italian query text here...", st.session_state.suggestion
139
+ )
140
+ with col2:
141
+ dataset_name = st.selectbox("IR dataset", ["Unsplash", "CC"])
142
 
 
143
  if query:
144
+ with st.spinner("Computing..."):
145
+
146
  model = get_model()
 
147
 
148
+ if dataset_name == "Unsplash":
149
+ download_images()
150
 
151
+ image_features = get_image_features(dataset_name)
152
  model = get_model()
153
  tokenizer = get_tokenizer()
154
 
155
+ if dataset_name == "Unsplash":
156
+ image_size = model.config.vision_config.image_size
157
+ dataset = utils.CustomDataSet(
158
+ "photos/", transform=get_image_transform(image_size)
159
+ )
160
+ elif dataset_name == "CC":
161
+ dataset = load_urls(dataset_name)
162
+ else:
163
+ raise ValueError()
 
 
 
 
 
 
164
 
165
  image_paths = utils.find_image(
166
+ query, model, dataset, tokenizer, image_features, 2, dataset_name
167
  )
168
 
169
  st.image(image_paths)
utils.py CHANGED
@@ -41,24 +41,36 @@ def text_encoder(text, model, tokenizer):
41
  return jnp.expand_dims(embedding, axis=0)
42
 
43
 
 
 
 
 
 
 
 
 
44
  def precompute_image_features(model, loader):
45
  image_features = []
46
  for i, (images) in enumerate(tqdm(loader)):
47
  images = images.permute(0, 2, 3, 1).numpy()
48
- features = model.get_image_features(
49
- images,
50
- )
51
  features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
52
  image_features.extend(features)
53
  return jnp.array(image_features)
54
 
55
 
56
- def find_image(text_query, model, dataset, tokenizer, image_features, n=1):
57
  zeroshot_weights = text_encoder(text_query, model, tokenizer)
58
  zeroshot_weights /= jnp.linalg.norm(zeroshot_weights)
59
  distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))
60
  file_paths = []
61
  for i in range(1, n + 1):
62
  idx = jnp.argsort(distances, axis=0)[-i, 0]
63
- file_paths.append("photos/" + dataset.get_image_name(idx))
 
 
 
 
 
 
64
  return file_paths
 
41
  return jnp.expand_dims(embedding, axis=0)
42
 
43
 
44
+ def image_encoder(image, model):
45
+ image = image.permute(1, 2, 0).numpy()
46
+ image = jnp.expand_dims(image, axis=0) #  add batch size
47
+ features = model.get_image_features(image,)
48
+ features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
49
+ return features
50
+
51
+
52
  def precompute_image_features(model, loader):
53
  image_features = []
54
  for i, (images) in enumerate(tqdm(loader)):
55
  images = images.permute(0, 2, 3, 1).numpy()
56
+ features = model.get_image_features(images,)
 
 
57
  features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
58
  image_features.extend(features)
59
  return jnp.array(image_features)
60
 
61
 
62
+ def find_image(text_query, model, dataset, tokenizer, image_features, n, dataset_name):
63
  zeroshot_weights = text_encoder(text_query, model, tokenizer)
64
  zeroshot_weights /= jnp.linalg.norm(zeroshot_weights)
65
  distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))
66
  file_paths = []
67
  for i in range(1, n + 1):
68
  idx = jnp.argsort(distances, axis=0)[-i, 0]
69
+
70
+ if dataset_name == "Unsplash":
71
+ file_paths.append("photos/" + dataset.get_image_name(idx))
72
+ elif dataset_name == "CC":
73
+ file_paths.append(dataset[idx])
74
+ else:
75
+ raise ValueError(f"{dataset_name} not supported here")
76
  return file_paths