Merge branch 'main' of https://huggingface.co/spaces/clip-italian/clip-italian-demo into main
- image2text.py +69 -1
- introduction.md +5 -5
- requirements.txt +2 -1
- static/CC_val_urls.txt +0 -0
- static/features/{cc_features.npy → CC_val_embeddings.npy} +2 -2
- text2image.py +100 -37
- utils.py +17 -5
image2text.py
CHANGED
@@ -1,4 +1,72 @@
 import streamlit as st
+from text2image import get_model, get_tokenizer, get_image_transform
+from utils import text_encoder, image_encoder
+from PIL import Image
+from jax import numpy as jnp
+import pandas as pd
+
 
 def app():
-
+    st.title("From Image to Text")
+    st.markdown(
+        """
+
+        ### 👋 Ciao!
+
+        Here you can find the captions that are most related to a given image.
+
+        🤌 Italian mode on! 🤌
+
+        """
+    )
+
+    filename = st.file_uploader(
+        "Choose an image from your computer", type=["jpg", "jpeg", "png"]
+    )
+
+    MAX_CAP = 4
+
+    col1, col2 = st.beta_columns([3, 1])
+
+    with col2:
+        captions_count = st.selectbox(
+            "Number of captions", options=range(1, MAX_CAP + 1)
+        )
+        compute = st.button("Compute")
+
+    with col1:
+        captions = list()
+        for idx in range(min(MAX_CAP, captions_count)):
+            captions.append(st.text_input(f"Insert Caption {idx+1}"))
+
+    if compute:
+        captions = [c for c in captions if c != ""]
+
+        if not captions or not filename:
+            st.error("Please choose one image and at least one caption")
+        else:
+            with st.spinner("Computing..."):
+                model = get_model()
+                tokenizer = get_tokenizer()
+
+                text_embeds = list()
+                for i, c in enumerate(captions):
+                    text_embeds.extend(text_encoder(c, model, tokenizer))
+
+                text_embeds = jnp.array(text_embeds)
+
+                image = Image.open(filename).convert("RGB")
+                transform = get_image_transform(model.config.vision_config.image_size)
+                image_embed = image_encoder(transform(image), model)
+
+                # we could have a softmax here
+                cos_similarities = jnp.matmul(image_embed, text_embeds.T)
+
+                chart_data = pd.Series(cos_similarities[0], index=captions)
+
+                col1, col2 = st.beta_columns(2)
+                with col1:
+                    st.bar_chart(chart_data)
+
+                with col2:
+                    st.image(image)
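Aside: because `text_encoder` and `image_encoder` both return L2-normalized embeddings, the `jnp.matmul` above already yields cosine similarities, and the softmax mentioned in the inline comment would only rescale those scores into weights. A self-contained sketch with toy unit vectors (assumes only jax; the numbers are illustrative):

import jax
from jax import numpy as jnp

# Toy embeddings; rows are L2-normalized, as in text_encoder/image_encoder.
image_embed = jnp.array([[0.6, 0.8]])
text_embeds = jnp.array([[1.0, 0.0], [0.0, 1.0]])

# Dot products of unit vectors are exactly cosine similarities.
cos_similarities = jnp.matmul(image_embed, text_embeds.T)
print(cos_similarities)                  # [[0.6 0.8]]

# The "softmax here" the comment mentions would turn scores into weights:
print(jax.nn.softmax(cos_similarities))  # ≈ [[0.45 0.55]]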
introduction.md
CHANGED
@@ -54,6 +54,8 @@ a dataset with 700K translated captions.
 
 ## Better Augmentations
 
+We knew that without a good augmentation strategy we could never get results competitive with a model trained on 400 million images, so we implemented heavy augmentations to make the training more data-efficient. We made sure to keep hue augmentations limited, however, to still give the model the ability to learn color definitions. While we would have liked to augment the captions as well, after some experimentation we settled on random sampling from the five captions available in MSCOCO and left the rest of the captions unmodified.
+
 ## Better Training
 
 After different trials, we realized that the usual way of training this model was
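A minimal sketch of the kind of image-augmentation pipeline this paragraph describes, assuming torchvision; the specific transforms and parameter values below are illustrative placeholders, not the settings used for training:

from torchvision import transforms

# Illustrative only: heavier geometric/photometric augmentation with a
# deliberately small hue range, so color words stay learnable.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
        transforms.RandomHorizontalFlip(),
        # brightness/contrast/saturation can be aggressive; hue stays limited
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.05),
        transforms.ToTensor(),
    ]
)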
@@ -62,17 +64,15 @@ training pipeline: the optimizer and the training with frozen components.
 
 ### Optimizer
 
-
+While the initial code used AdamW as the optimizer, we soon noticed that it introduced some bad properties into the training: the model started to overfit relatively quickly, and the weight decay made this effect worse. We eventually switched to an optimization strategy that had worked well for us in similar cases and used AdaBelief with Adaptive Gradient Clipping (AGC) and a Cosine Annealing Schedule. Together with slightly tuning the learning rate, this helped us reduce the validation loss by 25%.
 Our implementation is available online [here](https://github.com/clip-italian/clip-italian/blob/master/hybrid_clip/run_hybrid_clip.py#L667).
 
 ### Backbone Freezing
 
 The ViT used by OpenAI was already trained on 400 million images and it is the element in our architecture that probably required the least training.
-The same is true for the BERT model we use.
-the deeper layer to adapt to the new setting. Eventually, we run a new training, by fine-tuning al the components. This technique allowed us to
-reach a much better validation loss.
+The same is true for the BERT model we use. To allow the randomly initialized Re-projection Layers to warm up without disturbing the tuned weights of the backbones, we first trained with the backbones of our architecture completely frozen. Only after these layers converged did we unfreeze the rest of the model to fine-tune all the components. This technique allowed us to reach a much better validation loss.
 
-<img src="https://huggingface.co/spaces/clip-italian/clip-italian-demo/raw/main/static/img/clip-italian.png" alt="drawing" width="
+<img src="https://huggingface.co/spaces/clip-italian/clip-italian-demo/raw/main/static/img/clip-italian.png" alt="drawing" width="50%"/>
 
 # Scientific Validity
 
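A rough sketch of what such an optimizer chain can look like, assuming optax (the gradient-processing library used in JAX/Flax training scripts); the hyperparameters are illustrative, not the values from the linked run_hybrid_clip.py:

import optax

# Illustrative sketch (not the repo's exact settings): AdaBelief-style
# updates with Adaptive Gradient Clipping and cosine annealing.
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
optimizer = optax.chain(
    optax.adaptive_grad_clip(0.01),     # AGC: clip by parameter-relative norm
    optax.scale_by_belief(),            # AdaBelief second-moment estimate
    optax.scale_by_schedule(schedule),  # cosine-annealed step size
    optax.scale(-1.0),                  # descend rather than ascend
)

And a sketch of the two-phase freezing schedule, zeroing updates for backbone parameters during warm-up via optax.multi_transform; the collection names "vision_model" and "text_model" are assumptions about the FlaxHybridCLIP parameter tree, not taken from the repo:

from flax import traverse_util
import optax

def make_optimizer(freeze_backbones: bool):
    # Label every leaf "frozen" or "trainable"; backbone names are assumed.
    def param_labels(params):
        flat = traverse_util.flatten_dict(params)
        labels = {
            path: "frozen"
            if freeze_backbones and path[0] in ("vision_model", "text_model")
            else "trainable"
            for path in flat
        }
        return traverse_util.unflatten_dict(labels)

    return optax.multi_transform(
        {"trainable": optax.adam(1e-4), "frozen": optax.set_to_zero()},
        param_labels,
    )

# Phase 1: make_optimizer(True)  -> only the re-projection layers get updates.
# Phase 2: make_optimizer(False) -> fine-tune all the components.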
requirements.txt
CHANGED
@@ -4,4 +4,5 @@ transformers
 torch
 torchvision
 natsort
-stqdm
+stqdm
+pandas
static/CC_val_urls.txt
ADDED
The diff for this file is too large to render. See raw diff.
static/features/{cc_features.npy → CC_val_embeddings.npy}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:775803a42011b09e8f5d19fcbdd67123cc3447154e1f8e5990cae1bce4581662
+size 27369600
text2image.py
CHANGED
@@ -22,9 +22,15 @@ def get_model():
     return FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")
 
 
-@st.cache(
+@st.cache(
+    hash_funcs={
+        transformers.models.bert.tokenization_bert_fast.BertTokenizerFast: lambda _: None
+    }
+)
 def get_tokenizer():
-    return AutoTokenizer.from_pretrained(
+    return AutoTokenizer.from_pretrained(
+        "dbmdz/bert-base-italian-xxl-uncased", cache_dir="./", use_fast=True
+    )
 
 
 @st.cache(suppress_st_warning=True)
@@ -37,10 +43,14 @@ def download_images():
     photo_filename = "unsplash-25k-photos.zip"
     if not os.path.exists(photo_filename):  # Download dataset if does not exist
         print(f"Downloading {photo_filename}...")
-        response = requests.get(
-
+        response = requests.get(
+            f"http://sbert.net/datasets/{photo_filename}", stream=True
+        )
+        total_size_in_bytes = int(response.headers.get("content-length", 0))
         block_size = 1024  # 1 Kb
-        progress_bar = stqdm(
+        progress_bar = stqdm(
+            total=total_size_in_bytes
+        )  # , unit='iB', unit_scale=True
         content = io.BytesIO()
         for data in response.iter_content(block_size):
             progress_bar.update(len(data))
@@ -54,53 +64,106 @@ def download_images():
 
 
 @st.cache()
-def get_image_features():
-
-
-
-    """
-
-
-
-    In this demo you can search for images in the Unsplash 25k Photos dataset.
-
-    ""
-
-    query = st.text_input("Insert an italian query text here...")
+def get_image_features(dataset_name):
+    if dataset_name == "Unsplash":
+        return jnp.load("static/features/features.npy")
+    else:
+        return jnp.load("static/features/CC_val_embeddings.npy")
+
+
+@st.cache()
+def load_urls(dataset_name):
+    if dataset_name == "CC":
+        with open("static/CC_val_urls.txt") as fp:
+            urls = [l.strip() for l in fp.readlines()]
+        return urls
+    else:
+        raise ValueError(f"{dataset_name} not supported here")
+
+
+def get_image_transform(image_size):
+    return Compose(
+        [
+            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
+            CenterCrop(image_size),
+            ToTensor(),
+            Normalize(
+                (0.48145466, 0.4578275, 0.40821073),
+                (0.26862954, 0.26130258, 0.27577711),
+            ),
+        ]
+    )
+
+
+def app():
+    st.title("From Text to Image")
+    st.markdown(
+        """
+
+        ### 👋 Ciao!
+
+        Here you can search for images in the Unsplash 25k Photos dataset.
+
+        🤌 Italian mode on! 🤌
+
+        """
+    )
+
+    if "suggestion" not in st.session_state:
+        st.session_state.suggestion = ""
+
+    def update_query(value=""):
+        st.session_state.suggestion = value
+
+    col1, col2, col3, col4 = st.beta_columns(4)
+    with col1:
+        st.button("Un gatto", on_click=update_query, kwargs=dict(value="Un gatto"))
+    with col2:
+        st.button("Due gatti", on_click=update_query, kwargs=dict(value="Due gatti"))
+    with col3:
+        st.button(
+            "Un fiore giallo",
+            on_click=update_query,
+            kwargs=dict(value="Un fiore giallo"),
+        )
+    with col4:
+        st.button(
+            "Un fiore blu", on_click=update_query, kwargs=dict(value="Un fiore blu")
+        )
+
+    col1, col2 = st.beta_columns([3, 1])
+    with col1:
+        query = st.text_input(
+            "Insert an italian query text here...", st.session_state.suggestion
+        )
+    with col2:
+        dataset_name = st.selectbox("IR dataset", ["Unsplash", "CC"])
 
     if query:
-        with st.spinner("Computing
+        with st.spinner("Computing..."):
+
             model = get_model()
-            download_images()
 
-
+            if dataset_name == "Unsplash":
+                download_images()
 
+            image_features = get_image_features(dataset_name)
             model = get_model()
             tokenizer = get_tokenizer()
 
-
-
-
-
-
-
-
-
-
-                    (0.26862954, 0.26130258, 0.27577711),
-                ),
-            ]
-        )
-
-        dataset = utils.CustomDataSet("photos/", transform=val_preprocess)
+            if dataset_name == "Unsplash":
+                image_size = model.config.vision_config.image_size
+                dataset = utils.CustomDataSet(
+                    "photos/", transform=get_image_transform(image_size)
+                )
+            elif dataset_name == "CC":
+                dataset = load_urls(dataset_name)
+            else:
+                raise ValueError()
 
             image_paths = utils.find_image(
-                query, model, dataset, tokenizer, image_features,
+                query, model, dataset, tokenizer, image_features, 2, dataset_name
            )
 
             st.image(image_paths)
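For context, a hypothetical offline walkthrough of how these pieces compose for the new CC option; this snippet is not in the repo, and calling the `st.cache`-decorated helpers outside a running Streamlit app may emit cache warnings:

from text2image import get_model, get_tokenizer, get_image_features, load_urls
import utils

model = get_model()
tokenizer = get_tokenizer()

# Precomputed embeddings, one row per Conceptual Captions validation image.
image_features = get_image_features("CC")  # static/features/CC_val_embeddings.npy
urls = load_urls("CC")                     # static/CC_val_urls.txt, same row order

# Top-2 image URLs for an Italian query, exactly as app() wires it up.
best = utils.find_image(
    "un tramonto sul mare", model, urls, tokenizer, image_features, 2, "CC"
)
print(best)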
utils.py
CHANGED
@@ -41,24 +41,36 @@ def text_encoder(text, model, tokenizer):
     return jnp.expand_dims(embedding, axis=0)
 
 
+def image_encoder(image, model):
+    image = image.permute(1, 2, 0).numpy()
+    image = jnp.expand_dims(image, axis=0)  # add batch size
+    features = model.get_image_features(image,)
+    features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
+    return features
+
+
 def precompute_image_features(model, loader):
     image_features = []
     for i, (images) in enumerate(tqdm(loader)):
         images = images.permute(0, 2, 3, 1).numpy()
-        features = model.get_image_features(
-            images,
-        )
+        features = model.get_image_features(images,)
         features /= jnp.linalg.norm(features, axis=-1, keepdims=True)
         image_features.extend(features)
     return jnp.array(image_features)
 
 
-def find_image(text_query, model, dataset, tokenizer, image_features, n
+def find_image(text_query, model, dataset, tokenizer, image_features, n, dataset_name):
     zeroshot_weights = text_encoder(text_query, model, tokenizer)
     zeroshot_weights /= jnp.linalg.norm(zeroshot_weights)
     distances = jnp.dot(image_features, zeroshot_weights.reshape(-1, 1))
     file_paths = []
     for i in range(1, n + 1):
         idx = jnp.argsort(distances, axis=0)[-i, 0]
-
+
+        if dataset_name == "Unsplash":
+            file_paths.append("photos/" + dataset.get_image_name(idx))
+        elif dataset_name == "CC":
+            file_paths.append(dataset[idx])
+        else:
+            raise ValueError(f"{dataset_name} not supported here")
     return file_paths
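One detail worth noting in `find_image`: `jnp.argsort` sorts ascending, so indexing from the end with `[-i, 0]` walks the matches from best to worst. A tiny self-contained check:

import jax.numpy as jnp

# Three images, one similarity score each (a column vector, as in find_image).
distances = jnp.array([[0.1], [0.7], [0.4]])

# argsort is ascending, so [-1] is the best match, [-2] the runner-up, ...
top2 = [int(jnp.argsort(distances, axis=0)[-i, 0]) for i in range(1, 3)]
print(top2)  # [1, 2] — best match first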