Vivien
commited on
Commit
•
74e4bcd
1
Parent(s):
383bcb1
Create app
Browse files- .gitattributes +1 -0
- .gitignore +1 -0
- README.md +6 -5
- app.py +199 -0
- bpe_simple_vocab_16e6.txt.gz +3 -0
- data.csv +0 -0
- data2.csv +0 -0
- embeddings.npy +3 -0
- embeddings2.npy +3 -0
- embeddings2_slip_large.npy +3 -0
- embeddings_slip_large.npy +3 -0
- losses.py +132 -0
- models.py +331 -0
- requirements.txt +6 -0
- tokenizer.py +157 -0
- utils.py +213 -0
.gitattributes
CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
.vscode/
|
README.md
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
|
|
7 |
app_file: app.py
|
8 |
-
pinned:
|
9 |
---
|
10 |
|
11 |
# Configuration
|
|
|
1 |
---
|
2 |
+
title: Comparing CLIP and SLIP
|
3 |
+
emoji: 🖼️
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: blue
|
6 |
sdk: streamlit
|
7 |
+
sdk_version: 1.0.0
|
8 |
app_file: app.py
|
9 |
+
pinned: true
|
10 |
---
|
11 |
|
12 |
# Configuration
|
app.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import urllib.request
|
3 |
+
from collections import OrderedDict
|
4 |
+
from html import escape
|
5 |
+
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torchvision.transforms as transforms
|
11 |
+
|
12 |
+
from transformers import CLIPProcessor, CLIPModel
|
13 |
+
import tokenizers
|
14 |
+
import regex
|
15 |
+
|
16 |
+
import streamlit as st
|
17 |
+
|
18 |
+
import models
|
19 |
+
from tokenizer import SimpleTokenizer
|
20 |
+
|
21 |
+
cuda_available = torch.cuda.is_available()
|
22 |
+
|
23 |
+
model_url = "https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt"
|
24 |
+
model_filename = "slip_large_100ep.pt"
|
25 |
+
|
26 |
+
|
27 |
+
def get_model(model):
|
28 |
+
if isinstance(model, torch.nn.DataParallel) or isinstance(
|
29 |
+
model, torch.nn.parallel.DistributedDataParallel
|
30 |
+
):
|
31 |
+
return model.module
|
32 |
+
else:
|
33 |
+
return model
|
34 |
+
|
35 |
+
|
36 |
+
@st.cache(
|
37 |
+
show_spinner=False,
|
38 |
+
hash_funcs={
|
39 |
+
CLIPModel: lambda _: None,
|
40 |
+
CLIPProcessor: lambda _: None,
|
41 |
+
dict: lambda _: None,
|
42 |
+
},
|
43 |
+
)
|
44 |
+
def load():
|
45 |
+
# Load SLIP model from Facebook AI Research
|
46 |
+
if model_filename not in os.listdir():
|
47 |
+
urllib.request.urlretrieve(model_url, model_filename)
|
48 |
+
ckpt = torch.load("slip_large_100ep.pt", map_location="cpu")
|
49 |
+
state_dict = OrderedDict()
|
50 |
+
for k, v in ckpt["state_dict"].items():
|
51 |
+
state_dict[k.replace("module.", "")] = v
|
52 |
+
old_args = ckpt["args"]
|
53 |
+
slip_model = getattr(models, "SLIP_VITL16")(
|
54 |
+
rand_embed=False,
|
55 |
+
ssl_mlp_dim=old_args.ssl_mlp_dim,
|
56 |
+
ssl_emb_dim=old_args.ssl_emb_dim,
|
57 |
+
)
|
58 |
+
if cuda_available:
|
59 |
+
slip_model.cuda()
|
60 |
+
slip_model.load_state_dict(state_dict, strict=True)
|
61 |
+
slip_model = get_model(slip_model)
|
62 |
+
tokenizer = SimpleTokenizer()
|
63 |
+
del ckpt
|
64 |
+
del state_dict
|
65 |
+
# Load CLIP model from HuggingFace
|
66 |
+
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
67 |
+
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
68 |
+
# Load images' descriptions and embeddings
|
69 |
+
df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
|
70 |
+
embeddings = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}
|
71 |
+
slip_embeddings = {
|
72 |
+
0: np.load("embeddings_slip_large.npy"),
|
73 |
+
1: np.load("embeddings2_slip_large.npy"),
|
74 |
+
}
|
75 |
+
for k in [0, 1]:
|
76 |
+
embeddings[k] = np.divide(
|
77 |
+
embeddings[k], np.sqrt(np.sum(embeddings[k] ** 2, axis=1, keepdims=True))
|
78 |
+
)
|
79 |
+
return model, processor, slip_model, tokenizer, df, embeddings, slip_embeddings
|
80 |
+
|
81 |
+
|
82 |
+
model, processor, slip_model, tokenizer, df, embeddings, slip_embeddings = load()
|
83 |
+
|
84 |
+
source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
|
85 |
+
|
86 |
+
|
87 |
+
def get_html(url_list, url_list_slip, height=150):
|
88 |
+
html = "<div style='display: flex; flex-wrap: wrap; justify-content: space-evenly;'>"
|
89 |
+
html += "<span style='margin-top: 20px; max-width: 1200px; display: flex; align-content: flex-start; flex-wrap: wrap; justify-content: space-evenly; width: 50%'>"
|
90 |
+
html += "<div style='width: 100%; text-align: center;'><b>CLIP</b> (<a href='https://arxiv.org/abs/2103.00020'>Arxiv</a>, <a href='https://github.com/openai/CLIP'>GitHub</a>) from OpenAI</div>"
|
91 |
+
for url, title, link in url_list:
|
92 |
+
html2 = f"<img title='{escape(title)}' style='height: {height}px; margin: 5px' src='{escape(url)}'>"
|
93 |
+
if len(link) > 0:
|
94 |
+
html2 = f"<a href='{escape(link)}' target='_blank'>" + html2 + "</a>"
|
95 |
+
html = html + html2
|
96 |
+
html += "</span>"
|
97 |
+
html += "<span style='margin-top: 20px; max-width: 1200px; display: flex; align-content: flex-start; flex-wrap: wrap; justify-content: space-evenly; width: 50%; border-left: solid; border-color: #ffc423; border-width: thin;'>"
|
98 |
+
html += "<div style='width: 100%; text-align: center;'><b>SLIP</b> (<a href='https://arxiv.org/abs/2112.12750'>Arxiv</a>, <a href='https://github.com/facebookresearch/SLIP'>GitHub</a>) from Meta AI</div>"
|
99 |
+
for url, title, link in url_list_slip:
|
100 |
+
html2 = f"<img title='{escape(title)}' style='height: {height}px; margin: 5px' src='{escape(url)}'>"
|
101 |
+
if len(link) > 0:
|
102 |
+
html2 = f"<a href='{escape(link)}' target='_blank'>" + html2 + "</a>"
|
103 |
+
html = html + html2
|
104 |
+
html += "</span></div>"
|
105 |
+
return html
|
106 |
+
|
107 |
+
def compute_text_embeddings(list_of_strings):
|
108 |
+
inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
|
109 |
+
return model.get_text_features(**inputs)
|
110 |
+
|
111 |
+
def compute_text_embeddings_slip(list_of_strings):
|
112 |
+
texts = tokenizer(list_of_strings)
|
113 |
+
if cuda_available:
|
114 |
+
texts = texts.cuda(non_blocking=True)
|
115 |
+
texts = texts.view(-1, 77).contiguous()
|
116 |
+
return slip_model.encode_text(texts)
|
117 |
+
|
118 |
+
def image_search(query, corpus, n_results=24):
|
119 |
+
text_embeddings = compute_text_embeddings([query]).detach().numpy()
|
120 |
+
text_embeddings_slip = compute_text_embeddings_slip([query]).detach().numpy()
|
121 |
+
k = 0 if corpus == "Unsplash" else 1
|
122 |
+
results = np.argsort((embeddings[k] @ text_embeddings.T)[:, 0])[
|
123 |
+
-1 : -n_results - 1 : -1
|
124 |
+
]
|
125 |
+
results_slip = np.argsort((slip_embeddings[k] @ text_embeddings_slip.T)[:, 0])[
|
126 |
+
-1 : -n_results - 1 : -1
|
127 |
+
]
|
128 |
+
return (
|
129 |
+
[
|
130 |
+
(
|
131 |
+
df[k].iloc[i]["path"],
|
132 |
+
df[k].iloc[i]["tooltip"] + source[k],
|
133 |
+
df[k].iloc[i]["link"],
|
134 |
+
)
|
135 |
+
for i in results
|
136 |
+
],
|
137 |
+
[
|
138 |
+
(
|
139 |
+
df[k].iloc[i]["path"],
|
140 |
+
df[k].iloc[i]["tooltip"] + source[k],
|
141 |
+
df[k].iloc[i]["link"],
|
142 |
+
)
|
143 |
+
for i in results_slip
|
144 |
+
],
|
145 |
+
)
|
146 |
+
|
147 |
+
|
148 |
+
description = """
|
149 |
+
# Comparing CLIP and SLIP side by side
|
150 |
+
|
151 |
+
**Enter your query and hit enter**
|
152 |
+
|
153 |
+
CLIP and SLIP are ML models that encode images and texts as vectors so that the vectors of an image and its caption are similar. They can notably be used for zero-shot image classification, text-based image retrieval or image generation.
|
154 |
+
|
155 |
+
*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, Meta AI's [SLIP](https://github.com/facebookresearch/SLIP) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
|
156 |
+
"""
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
st.markdown(
|
161 |
+
"""
|
162 |
+
<style>
|
163 |
+
.block-container{
|
164 |
+
max-width: 1200px;
|
165 |
+
}
|
166 |
+
div.row-widget.stRadio > div{
|
167 |
+
flex-direction:row;
|
168 |
+
display: flex;
|
169 |
+
justify-content: center;
|
170 |
+
}
|
171 |
+
div.row-widget.stRadio > div > label{
|
172 |
+
margin-left: 5px;
|
173 |
+
margin-right: 5px;
|
174 |
+
}
|
175 |
+
section.main>div:first-child {
|
176 |
+
padding-top: 0px;
|
177 |
+
}
|
178 |
+
section:not(.main)>div:first-child {
|
179 |
+
padding-top: 30px;
|
180 |
+
}
|
181 |
+
div.reportview-container > section:first-child{
|
182 |
+
max-width: 320px;
|
183 |
+
}
|
184 |
+
#MainMenu {
|
185 |
+
visibility: hidden;
|
186 |
+
}
|
187 |
+
footer {
|
188 |
+
visibility: hidden;
|
189 |
+
}
|
190 |
+
</style>""",
|
191 |
+
unsafe_allow_html=True,
|
192 |
+
)
|
193 |
+
st.sidebar.markdown(description)
|
194 |
+
_, c, _ = st.columns((1, 3, 1))
|
195 |
+
query = c.text_input("", value="clouds at sunset")
|
196 |
+
corpus = st.radio("", ["Unsplash", "Movies"])
|
197 |
+
if len(query) > 0:
|
198 |
+
results, results_slip = image_search(query, corpus)
|
199 |
+
st.markdown(get_html(results, results_slip), unsafe_allow_html=True)
|
bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
data.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data2.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
embeddings.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9f8c171e32276739be6b020592edc8a2c06e029ff6505a9d1d4efe3cafa073bd
|
3 |
+
size 51200128
|
embeddings2.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9664e980f31e81c4a34e07833539fea32795d83a4262c9828ceae445fa2e412a
|
3 |
+
size 16732288
|
embeddings2_slip_large.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5632813e4a27062f2a7bc3f2db23ac3f62d946b53d3b9144c1d5c7e8f9865f90
|
3 |
+
size 16732288
|
embeddings_slip_large.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:98fd7411e6874bfd703c134470b9e5a82c0a7a403bb1cf1cac5851dc3871498f
|
3 |
+
size 51200128
|
losses.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
import utils
|
11 |
+
|
12 |
+
|
13 |
+
class CLIPLoss(nn.Module):
|
14 |
+
def __init__(self):
|
15 |
+
super().__init__()
|
16 |
+
self.labels = None
|
17 |
+
self.last_local_batch_size = None
|
18 |
+
|
19 |
+
def forward(self, outputs):
|
20 |
+
image_embed = outputs['image_embed']
|
21 |
+
text_embed = outputs['text_embed']
|
22 |
+
logit_scale = outputs['logit_scale']
|
23 |
+
local_batch_size = image_embed.size(0)
|
24 |
+
|
25 |
+
if local_batch_size != self.last_local_batch_size:
|
26 |
+
self.labels = local_batch_size * utils.get_rank() + torch.arange(
|
27 |
+
local_batch_size, device=image_embed.device
|
28 |
+
)
|
29 |
+
self.last_local_batch_size = local_batch_size
|
30 |
+
|
31 |
+
# normalized features
|
32 |
+
image_embed = F.normalize(image_embed, dim=-1, p=2)
|
33 |
+
text_embed = F.normalize(text_embed, dim=-1, p=2)
|
34 |
+
|
35 |
+
# gather features from all GPUs
|
36 |
+
image_embed_all, text_embed_all = \
|
37 |
+
utils.all_gather_batch([image_embed, text_embed])
|
38 |
+
|
39 |
+
# cosine similarity as logits
|
40 |
+
logits_per_image = logit_scale * image_embed @ text_embed_all.t()
|
41 |
+
logits_per_text = logit_scale * text_embed @ image_embed_all.t()
|
42 |
+
|
43 |
+
loss = (F.cross_entropy(logits_per_image, self.labels) + \
|
44 |
+
F.cross_entropy(logits_per_text, self.labels)) / 2
|
45 |
+
|
46 |
+
# compute accuracy
|
47 |
+
with torch.no_grad():
|
48 |
+
pred = torch.argmax(logits_per_image, dim=-1)
|
49 |
+
correct = pred.eq(self.labels).sum()
|
50 |
+
acc = 100 * correct / local_batch_size
|
51 |
+
|
52 |
+
return {'loss': loss, 'clip_loss': loss, 'clip_acc': acc}
|
53 |
+
|
54 |
+
|
55 |
+
class SIMCLRLoss(nn.Module):
|
56 |
+
"""
|
57 |
+
This is the SimCLR loss in https://arxiv.org/abs/2002.05709
|
58 |
+
The embedding vectors are assumed to have size (2 x batch_size, embedding_dim) and
|
59 |
+
the memory layout that can be reshaped into shape (2, batch_size, embedding_dim).
|
60 |
+
This memory layout is consistent with the SimCLR collator in
|
61 |
+
https://github.com/facebookresearch/vissl/blob/master/vissl/data/collators/simclr_collator.py
|
62 |
+
Config params:
|
63 |
+
temperature (float): the temperature to be applied on the logits
|
64 |
+
"""
|
65 |
+
|
66 |
+
def __init__(self, temperature=0.1):
|
67 |
+
super().__init__()
|
68 |
+
self.tau = temperature
|
69 |
+
self.labels = None
|
70 |
+
self.masks = None
|
71 |
+
self.last_local_batch_size = None
|
72 |
+
|
73 |
+
def forward(self, outputs):
|
74 |
+
q_a = outputs['aug1_embed']
|
75 |
+
q_b = outputs['aug2_embed']
|
76 |
+
|
77 |
+
q_a = F.normalize(q_a, dim=-1, p=2)
|
78 |
+
q_b = F.normalize(q_b, dim=-1, p=2)
|
79 |
+
|
80 |
+
local_batch_size = q_a.size(0)
|
81 |
+
|
82 |
+
k_a, k_b = utils.all_gather_batch_with_grad([q_a, q_b])
|
83 |
+
|
84 |
+
if local_batch_size != self.last_local_batch_size:
|
85 |
+
self.labels = local_batch_size * utils.get_rank() + torch.arange(
|
86 |
+
local_batch_size, device=q_a.device
|
87 |
+
)
|
88 |
+
total_batch_size = local_batch_size * utils.get_world_size()
|
89 |
+
self.masks = F.one_hot(self.labels, total_batch_size) * 1e9
|
90 |
+
self.last_local_batch_size = local_batch_size
|
91 |
+
|
92 |
+
logits_aa = torch.matmul(q_a, k_a.transpose(0, 1)) / self.tau
|
93 |
+
logits_aa = logits_aa - self.masks
|
94 |
+
logits_bb = torch.matmul(q_b, k_b.transpose(0, 1)) / self.tau
|
95 |
+
logits_bb = logits_bb - self.masks
|
96 |
+
logits_ab = torch.matmul(q_a, k_b.transpose(0, 1)) / self.tau
|
97 |
+
logits_ba = torch.matmul(q_b, k_a.transpose(0, 1)) / self.tau
|
98 |
+
|
99 |
+
loss_a = F.cross_entropy(torch.cat([logits_ab, logits_aa], dim=1), self.labels)
|
100 |
+
loss_b = F.cross_entropy(torch.cat([logits_ba, logits_bb], dim=1), self.labels)
|
101 |
+
loss = (loss_a + loss_b) / 2 # divide by 2 to average over all samples
|
102 |
+
|
103 |
+
# compute accuracy
|
104 |
+
with torch.no_grad():
|
105 |
+
pred = torch.argmax(torch.cat([logits_ab, logits_aa], dim=1), dim=-1)
|
106 |
+
correct = pred.eq(self.labels).sum()
|
107 |
+
acc = 100 * correct / local_batch_size
|
108 |
+
|
109 |
+
return {'loss': loss, 'ssl_loss': loss, 'ssl_acc': acc}
|
110 |
+
|
111 |
+
|
112 |
+
class SLIPLoss(nn.Module):
|
113 |
+
def __init__(self, ssl_loss, ssl_scale):
|
114 |
+
super().__init__()
|
115 |
+
self.clip_loss = CLIPLoss()
|
116 |
+
self.ssl_loss = ssl_loss
|
117 |
+
self.ssl_scale = ssl_scale
|
118 |
+
|
119 |
+
def forward(self, outputs):
|
120 |
+
clip_loss_dict = self.clip_loss(outputs)
|
121 |
+
clip_loss = clip_loss_dict['clip_loss']
|
122 |
+
clip_acc = clip_loss_dict['clip_acc']
|
123 |
+
|
124 |
+
ssl_loss_dict = self.ssl_loss(outputs)
|
125 |
+
ssl_loss = ssl_loss_dict['ssl_loss']
|
126 |
+
ssl_acc = ssl_loss_dict['ssl_acc']
|
127 |
+
|
128 |
+
return {'loss': clip_loss + self.ssl_scale * ssl_loss,
|
129 |
+
'clip_loss': clip_loss,
|
130 |
+
'clip_acc': clip_acc,
|
131 |
+
'ssl_loss': ssl_loss,
|
132 |
+
'ssl_acc': ssl_acc}
|
models.py
ADDED
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Modified from github.com/openai/CLIP
|
8 |
+
from collections import OrderedDict
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import timm
|
12 |
+
import torch
|
13 |
+
from torch import nn
|
14 |
+
|
15 |
+
import losses
|
16 |
+
|
17 |
+
|
18 |
+
class LayerNorm(nn.LayerNorm):
|
19 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
20 |
+
|
21 |
+
def forward(self, x: torch.Tensor):
|
22 |
+
orig_type = x.dtype
|
23 |
+
ret = super().forward(x.type(torch.float32))
|
24 |
+
return ret.type(orig_type)
|
25 |
+
|
26 |
+
|
27 |
+
class QuickGELU(nn.Module):
|
28 |
+
def forward(self, x: torch.Tensor):
|
29 |
+
return x * torch.sigmoid(1.702 * x)
|
30 |
+
|
31 |
+
|
32 |
+
class ResidualAttentionBlock(nn.Module):
|
33 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
34 |
+
super().__init__()
|
35 |
+
|
36 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
37 |
+
self.ln_1 = LayerNorm(d_model)
|
38 |
+
self.mlp = nn.Sequential(OrderedDict([
|
39 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
40 |
+
("gelu", QuickGELU()),
|
41 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
42 |
+
]))
|
43 |
+
self.ln_2 = LayerNorm(d_model)
|
44 |
+
self.attn_mask = attn_mask
|
45 |
+
|
46 |
+
def attention(self, x: torch.Tensor):
|
47 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
48 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
49 |
+
|
50 |
+
def forward(self, x: torch.Tensor):
|
51 |
+
x = x + self.attention(self.ln_1(x))
|
52 |
+
x = x + self.mlp(self.ln_2(x))
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
class Transformer(nn.Module):
|
57 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
58 |
+
super().__init__()
|
59 |
+
self.width = width
|
60 |
+
self.layers = layers
|
61 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
62 |
+
|
63 |
+
def forward(self, x: torch.Tensor):
|
64 |
+
return self.resblocks(x)
|
65 |
+
|
66 |
+
|
67 |
+
class CLIP(nn.Module):
|
68 |
+
def __init__(self,
|
69 |
+
embed_dim: int,
|
70 |
+
# vision
|
71 |
+
vision_width: int,
|
72 |
+
vision_model: nn.Module,
|
73 |
+
# text
|
74 |
+
context_length: int,
|
75 |
+
vocab_size: int,
|
76 |
+
transformer_width: int,
|
77 |
+
transformer_heads: int,
|
78 |
+
transformer_layers: int,
|
79 |
+
**kwargs,
|
80 |
+
):
|
81 |
+
super().__init__()
|
82 |
+
|
83 |
+
self.context_length = context_length
|
84 |
+
self.vision_width = vision_width
|
85 |
+
|
86 |
+
self.visual = vision_model
|
87 |
+
|
88 |
+
self.transformer = Transformer(
|
89 |
+
width=transformer_width,
|
90 |
+
layers=transformer_layers,
|
91 |
+
heads=transformer_heads,
|
92 |
+
attn_mask=self.build_attention_mask(),
|
93 |
+
)
|
94 |
+
|
95 |
+
self.vocab_size = vocab_size
|
96 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
97 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
98 |
+
self.ln_final = LayerNorm(transformer_width)
|
99 |
+
|
100 |
+
self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
|
101 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
102 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
103 |
+
|
104 |
+
self.initialize_parameters()
|
105 |
+
|
106 |
+
def initialize_parameters(self):
|
107 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
108 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
109 |
+
|
110 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
111 |
+
attn_std = self.transformer.width ** -0.5
|
112 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
113 |
+
for block in self.transformer.resblocks:
|
114 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
115 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
116 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
117 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
118 |
+
|
119 |
+
nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)
|
120 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
121 |
+
|
122 |
+
def build_attention_mask(self):
|
123 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
124 |
+
# pytorch uses additive attention mask; fill with -inf
|
125 |
+
mask = torch.empty(self.context_length, self.context_length)
|
126 |
+
mask.fill_(float("-inf"))
|
127 |
+
mask.triu_(1) # zero out the lower diagonal
|
128 |
+
return mask
|
129 |
+
|
130 |
+
def encode_image(self, image):
|
131 |
+
x = self.visual(image)
|
132 |
+
x = x @ self.image_projection
|
133 |
+
|
134 |
+
return x
|
135 |
+
|
136 |
+
def encode_text(self, text):
|
137 |
+
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
|
138 |
+
x = x + self.positional_embedding
|
139 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
140 |
+
x = self.transformer(x)
|
141 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
142 |
+
x = self.ln_final(x)
|
143 |
+
|
144 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
145 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
146 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
147 |
+
|
148 |
+
return x
|
149 |
+
|
150 |
+
def forward(self, image, text):
|
151 |
+
image_embed = self.encode_image(image)
|
152 |
+
text_embed = self.encode_text(text)
|
153 |
+
|
154 |
+
return {'image_embed': image_embed,
|
155 |
+
'text_embed': text_embed,
|
156 |
+
'logit_scale': self.logit_scale.exp()}
|
157 |
+
|
158 |
+
|
159 |
+
class SIMCLR(nn.Module):
|
160 |
+
def __init__(self,
|
161 |
+
# vision
|
162 |
+
vision_width: int,
|
163 |
+
vision_model: nn.Module,
|
164 |
+
# ssl
|
165 |
+
ssl_mlp_dim: int,
|
166 |
+
ssl_emb_dim: int,
|
167 |
+
**kwargs,
|
168 |
+
):
|
169 |
+
super().__init__()
|
170 |
+
|
171 |
+
self.vision_width = vision_width
|
172 |
+
self.visual = vision_model
|
173 |
+
|
174 |
+
self.image_mlp = self._build_mlp(in_dim=vision_width, mlp_dim=ssl_mlp_dim, out_dim=ssl_emb_dim)
|
175 |
+
|
176 |
+
def _build_mlp(self, in_dim, mlp_dim, out_dim):
|
177 |
+
return nn.Sequential(OrderedDict([
|
178 |
+
("layer1", nn.Linear(in_dim, mlp_dim)),
|
179 |
+
("bn1", nn.SyncBatchNorm(mlp_dim)),
|
180 |
+
("relu1", nn.ReLU(inplace=True)),
|
181 |
+
("layer2", nn.Linear(mlp_dim, mlp_dim)),
|
182 |
+
("bn2", nn.SyncBatchNorm(mlp_dim)),
|
183 |
+
("relu2", nn.ReLU(inplace=True)),
|
184 |
+
("layer3", nn.Linear(mlp_dim, out_dim)),
|
185 |
+
]))
|
186 |
+
|
187 |
+
def encode_image(self, image):
|
188 |
+
x = self.visual(image)
|
189 |
+
|
190 |
+
return x
|
191 |
+
|
192 |
+
def forward(self, aug1, aug2):
|
193 |
+
h1 = self.visual(aug1)
|
194 |
+
h2 = self.visual(aug2)
|
195 |
+
|
196 |
+
aug1_embed = self.image_mlp(h1)
|
197 |
+
aug2_embed = self.image_mlp(h2)
|
198 |
+
|
199 |
+
return {'aug1_embed': aug1_embed,
|
200 |
+
'aug2_embed': aug2_embed}
|
201 |
+
|
202 |
+
|
203 |
+
class SLIP(CLIP):
|
204 |
+
def __init__(self,
|
205 |
+
ssl_mlp_dim: int,
|
206 |
+
ssl_emb_dim: int,
|
207 |
+
**kwargs,
|
208 |
+
):
|
209 |
+
super().__init__(**kwargs)
|
210 |
+
|
211 |
+
self.image_mlp = self._build_mlp(in_dim=self.vision_width, mlp_dim=ssl_mlp_dim, out_dim=ssl_emb_dim)
|
212 |
+
|
213 |
+
def _build_mlp(self, in_dim, mlp_dim, out_dim):
|
214 |
+
return nn.Sequential(OrderedDict([
|
215 |
+
("layer1", nn.Linear(in_dim, mlp_dim)),
|
216 |
+
("bn1", nn.SyncBatchNorm(mlp_dim)),
|
217 |
+
("relu1", nn.ReLU(inplace=True)),
|
218 |
+
("layer2", nn.Linear(mlp_dim, mlp_dim)),
|
219 |
+
("bn2", nn.SyncBatchNorm(mlp_dim)),
|
220 |
+
("relu2", nn.ReLU(inplace=True)),
|
221 |
+
("layer3", nn.Linear(mlp_dim, out_dim)),
|
222 |
+
]))
|
223 |
+
|
224 |
+
def forward(self, image, text, aug1, aug2):
|
225 |
+
aug1_embed = self.image_mlp(self.visual(aug1))
|
226 |
+
aug2_embed = self.image_mlp(self.visual(aug2))
|
227 |
+
|
228 |
+
image_embed = self.encode_image(image)
|
229 |
+
text_embed = self.encode_text(text)
|
230 |
+
|
231 |
+
return {'image_embed': image_embed,
|
232 |
+
'text_embed': text_embed,
|
233 |
+
'logit_scale': self.logit_scale.exp(),
|
234 |
+
'aug1_embed': aug1_embed,
|
235 |
+
'aug2_embed': aug2_embed}
|
236 |
+
|
237 |
+
|
238 |
+
def get_loss(model, ssl_temp, ssl_scale):
|
239 |
+
if model.startswith('SLIP'):
|
240 |
+
ssl_loss = losses.SIMCLRLoss(temperature=ssl_temp)
|
241 |
+
return losses.SLIPLoss(ssl_loss, ssl_scale)
|
242 |
+
if model.startswith('CLIP'):
|
243 |
+
return losses.CLIPLoss()
|
244 |
+
if model.startswith('SIMCLR'):
|
245 |
+
return losses.SIMCLRLoss(temperature=ssl_temp)
|
246 |
+
|
247 |
+
|
248 |
+
def get_metric_names(model):
|
249 |
+
if model.startswith('SLIP'):
|
250 |
+
return ['loss', 'clip_loss', 'ssl_loss', 'clip_acc', 'ssl_acc']
|
251 |
+
elif model.startswith('CLIP'):
|
252 |
+
return ['loss', 'clip_loss', 'clip_acc']
|
253 |
+
else:
|
254 |
+
return ['loss', 'ssl_loss', 'ssl_acc']
|
255 |
+
|
256 |
+
|
257 |
+
@timm.models.registry.register_model
|
258 |
+
def vit_small_mocov3_patch16_224(**kwargs):
|
259 |
+
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12, **kwargs)
|
260 |
+
model = timm.models.vision_transformer._create_vision_transformer('vit_small_patch16_224', **model_kwargs)
|
261 |
+
|
262 |
+
return model
|
263 |
+
|
264 |
+
|
265 |
+
def CLIP_VITS16(**kwargs):
|
266 |
+
vision_model = timm.create_model('vit_small_mocov3_patch16_224', num_classes=0)
|
267 |
+
model = CLIP(embed_dim=512, vision_width=384, vision_model=vision_model, context_length=77, vocab_size=49408,
|
268 |
+
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs)
|
269 |
+
|
270 |
+
return model
|
271 |
+
|
272 |
+
|
273 |
+
def SIMCLR_VITS16(**kwargs):
|
274 |
+
vision_model = timm.create_model('vit_small_mocov3_patch16_224', num_classes=0)
|
275 |
+
model = SIMCLR(vision_width=384, vision_model=vision_model, **kwargs)
|
276 |
+
|
277 |
+
return model
|
278 |
+
|
279 |
+
|
280 |
+
def SLIP_VITS16(**kwargs):
|
281 |
+
vision_model = timm.create_model('vit_small_mocov3_patch16_224', num_classes=0)
|
282 |
+
model = SLIP(embed_dim=512, vision_width=384, vision_model=vision_model, context_length=77, vocab_size=49408,
|
283 |
+
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs)
|
284 |
+
|
285 |
+
return model
|
286 |
+
|
287 |
+
|
288 |
+
def CLIP_VITB16(**kwargs):
|
289 |
+
vision_model = timm.create_model('vit_base_patch16_224', num_classes=0)
|
290 |
+
model = CLIP(embed_dim=512, vision_width=768, vision_model=vision_model, context_length=77, vocab_size=49408,
|
291 |
+
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs)
|
292 |
+
|
293 |
+
return model
|
294 |
+
|
295 |
+
|
296 |
+
def SIMCLR_VITB16(**kwargs):
|
297 |
+
vision_model = timm.create_model('vit_base_patch16_224', num_classes=0)
|
298 |
+
model = SIMCLR(vision_width=768, vision_model=vision_model, **kwargs)
|
299 |
+
|
300 |
+
return model
|
301 |
+
|
302 |
+
|
303 |
+
def SLIP_VITB16(**kwargs):
|
304 |
+
vision_model = timm.create_model('vit_base_patch16_224', num_classes=0)
|
305 |
+
model = SLIP(embed_dim=512, vision_width=768, vision_model=vision_model, context_length=77, vocab_size=49408,
|
306 |
+
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs)
|
307 |
+
|
308 |
+
return model
|
309 |
+
|
310 |
+
|
311 |
+
def CLIP_VITL16(**kwargs):
|
312 |
+
vision_model = timm.create_model('vit_large_patch16_224', num_classes=0)
|
313 |
+
model = CLIP(embed_dim=512, vision_width=1024, vision_model=vision_model, context_length=77, vocab_size=49408,
|
314 |
+
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs)
|
315 |
+
|
316 |
+
return model
|
317 |
+
|
318 |
+
|
319 |
+
def SIMCLR_VITL16(**kwargs):
|
320 |
+
vision_model = timm.create_model('vit_large_patch16_224', num_classes=0)
|
321 |
+
model = SIMCLR(vision_width=1024, vision_model=vision_model, **kwargs)
|
322 |
+
|
323 |
+
return model
|
324 |
+
|
325 |
+
|
326 |
+
def SLIP_VITL16(**kwargs):
|
327 |
+
vision_model = timm.create_model('vit_large_patch16_224', num_classes=0)
|
328 |
+
model = SLIP(embed_dim=512, vision_width=1024, vision_model=vision_model, context_length=77, vocab_size=49408,
|
329 |
+
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs)
|
330 |
+
|
331 |
+
return model
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torchvision
|
2 |
+
transformers
|
3 |
+
numpy
|
4 |
+
pandas
|
5 |
+
timm
|
6 |
+
ftfy
|
tokenizer.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Modified from github.com/openai/CLIP
|
8 |
+
import gzip
|
9 |
+
import html
|
10 |
+
import os
|
11 |
+
from functools import lru_cache
|
12 |
+
|
13 |
+
import ftfy
|
14 |
+
import regex as re
|
15 |
+
import torch
|
16 |
+
|
17 |
+
|
18 |
+
@lru_cache()
|
19 |
+
def default_bpe():
|
20 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
21 |
+
|
22 |
+
|
23 |
+
@lru_cache()
|
24 |
+
def bytes_to_unicode():
|
25 |
+
"""
|
26 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
27 |
+
The reversible bpe codes work on unicode strings.
|
28 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
29 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
30 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
31 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
32 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
33 |
+
"""
|
34 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
35 |
+
cs = bs[:]
|
36 |
+
n = 0
|
37 |
+
for b in range(2**8):
|
38 |
+
if b not in bs:
|
39 |
+
bs.append(b)
|
40 |
+
cs.append(2**8+n)
|
41 |
+
n += 1
|
42 |
+
cs = [chr(n) for n in cs]
|
43 |
+
return dict(zip(bs, cs))
|
44 |
+
|
45 |
+
|
46 |
+
def get_pairs(word):
|
47 |
+
"""Return set of symbol pairs in a word.
|
48 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
49 |
+
"""
|
50 |
+
pairs = set()
|
51 |
+
prev_char = word[0]
|
52 |
+
for char in word[1:]:
|
53 |
+
pairs.add((prev_char, char))
|
54 |
+
prev_char = char
|
55 |
+
return pairs
|
56 |
+
|
57 |
+
|
58 |
+
def basic_clean(text):
|
59 |
+
text = ftfy.fix_text(text)
|
60 |
+
text = html.unescape(html.unescape(text))
|
61 |
+
return text.strip()
|
62 |
+
|
63 |
+
|
64 |
+
def whitespace_clean(text):
|
65 |
+
text = re.sub(r'\s+', ' ', text)
|
66 |
+
text = text.strip()
|
67 |
+
return text
|
68 |
+
|
69 |
+
|
70 |
+
class SimpleTokenizer(object):
|
71 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
72 |
+
self.byte_encoder = bytes_to_unicode()
|
73 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
74 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
75 |
+
merges = merges[1:49152-256-2+1]
|
76 |
+
merges = [tuple(merge.split()) for merge in merges]
|
77 |
+
vocab = list(bytes_to_unicode().values())
|
78 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
79 |
+
for merge in merges:
|
80 |
+
vocab.append(''.join(merge))
|
81 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
82 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
83 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
84 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
85 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
86 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
87 |
+
|
88 |
+
def bpe(self, token):
|
89 |
+
if token in self.cache:
|
90 |
+
return self.cache[token]
|
91 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
92 |
+
pairs = get_pairs(word)
|
93 |
+
|
94 |
+
if not pairs:
|
95 |
+
return token+'</w>'
|
96 |
+
|
97 |
+
while True:
|
98 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
99 |
+
if bigram not in self.bpe_ranks:
|
100 |
+
break
|
101 |
+
first, second = bigram
|
102 |
+
new_word = []
|
103 |
+
i = 0
|
104 |
+
while i < len(word):
|
105 |
+
try:
|
106 |
+
j = word.index(first, i)
|
107 |
+
new_word.extend(word[i:j])
|
108 |
+
i = j
|
109 |
+
except:
|
110 |
+
new_word.extend(word[i:])
|
111 |
+
break
|
112 |
+
|
113 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
114 |
+
new_word.append(first+second)
|
115 |
+
i += 2
|
116 |
+
else:
|
117 |
+
new_word.append(word[i])
|
118 |
+
i += 1
|
119 |
+
new_word = tuple(new_word)
|
120 |
+
word = new_word
|
121 |
+
if len(word) == 1:
|
122 |
+
break
|
123 |
+
else:
|
124 |
+
pairs = get_pairs(word)
|
125 |
+
word = ' '.join(word)
|
126 |
+
self.cache[token] = word
|
127 |
+
return word
|
128 |
+
|
129 |
+
def encode(self, text):
|
130 |
+
bpe_tokens = []
|
131 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
132 |
+
for token in re.findall(self.pat, text):
|
133 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
134 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
135 |
+
return bpe_tokens
|
136 |
+
|
137 |
+
def decode(self, tokens):
|
138 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
139 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
140 |
+
return text
|
141 |
+
|
142 |
+
def __call__(self, texts, context_length=77):
|
143 |
+
if isinstance(texts, str):
|
144 |
+
texts = [texts]
|
145 |
+
|
146 |
+
sot_token = self.encoder["<|startoftext|>"]
|
147 |
+
eot_token = self.encoder["<|endoftext|>"]
|
148 |
+
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
|
149 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
150 |
+
|
151 |
+
for i, tokens in enumerate(all_tokens):
|
152 |
+
tokens = tokens[:context_length]
|
153 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
154 |
+
|
155 |
+
if len(result) == 1:
|
156 |
+
return result[0]
|
157 |
+
return result
|
utils.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
import numpy as np
|
7 |
+
import os
|
8 |
+
import random
|
9 |
+
import shutil
|
10 |
+
import torch
|
11 |
+
import torch.distributed as dist
|
12 |
+
import torch.autograd as autograd
|
13 |
+
|
14 |
+
from PIL import ImageFilter
|
15 |
+
|
16 |
+
|
17 |
+
def get_model(model):
|
18 |
+
if isinstance(model, torch.nn.DataParallel) \
|
19 |
+
or isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
20 |
+
return model.module
|
21 |
+
else:
|
22 |
+
return model
|
23 |
+
|
24 |
+
|
25 |
+
def setup_for_distributed(is_master):
|
26 |
+
"""
|
27 |
+
This function disables printing when not in master process
|
28 |
+
"""
|
29 |
+
import builtins as __builtin__
|
30 |
+
builtin_print = __builtin__.print
|
31 |
+
|
32 |
+
def print(*args, **kwargs):
|
33 |
+
force = kwargs.pop('force', False)
|
34 |
+
if is_master or force:
|
35 |
+
builtin_print(*args, **kwargs)
|
36 |
+
|
37 |
+
__builtin__.print = print
|
38 |
+
|
39 |
+
|
40 |
+
def is_dist_avail_and_initialized():
|
41 |
+
if not dist.is_available():
|
42 |
+
return False
|
43 |
+
if not dist.is_initialized():
|
44 |
+
return False
|
45 |
+
return True
|
46 |
+
|
47 |
+
|
48 |
+
def get_world_size():
|
49 |
+
if not is_dist_avail_and_initialized():
|
50 |
+
return 1
|
51 |
+
return dist.get_world_size()
|
52 |
+
|
53 |
+
|
54 |
+
def get_rank():
|
55 |
+
if not is_dist_avail_and_initialized():
|
56 |
+
return 0
|
57 |
+
return dist.get_rank()
|
58 |
+
|
59 |
+
|
60 |
+
def is_main_process():
|
61 |
+
return get_rank() == 0
|
62 |
+
|
63 |
+
|
64 |
+
def save_on_master(state, is_best, output_dir):
|
65 |
+
if is_main_process():
|
66 |
+
ckpt_path = f'{output_dir}/checkpoint.pt'
|
67 |
+
best_path = f'{output_dir}/checkpoint_best.pt'
|
68 |
+
torch.save(state, ckpt_path)
|
69 |
+
if is_best:
|
70 |
+
shutil.copyfile(ckpt_path, best_path)
|
71 |
+
|
72 |
+
|
73 |
+
def init_distributed_mode(args):
|
74 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
75 |
+
args.rank = int(os.environ["RANK"])
|
76 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
77 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
78 |
+
elif 'SLURM_PROCID' in os.environ:
|
79 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
80 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
81 |
+
else:
|
82 |
+
print('Not using distributed mode')
|
83 |
+
args.distributed = False
|
84 |
+
return
|
85 |
+
|
86 |
+
args.distributed = True
|
87 |
+
|
88 |
+
torch.cuda.set_device(args.gpu)
|
89 |
+
args.dist_backend = 'nccl'
|
90 |
+
print('| distributed init (rank {}): {}'.format(
|
91 |
+
args.rank, args.dist_url), flush=True)
|
92 |
+
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
93 |
+
world_size=args.world_size, rank=args.rank)
|
94 |
+
torch.distributed.barrier()
|
95 |
+
setup_for_distributed(args.rank == 0)
|
96 |
+
|
97 |
+
|
98 |
+
def scaled_all_reduce(tensors, is_scale=True):
|
99 |
+
"""Performs the scaled all_reduce operation on the provided tensors.
|
100 |
+
The input tensors are modified in-place. Currently supports only the sum
|
101 |
+
reduction operator. The reduced values are scaled by the inverse size of the
|
102 |
+
world size.
|
103 |
+
"""
|
104 |
+
world_size = get_world_size()
|
105 |
+
# There is no need for reduction in the single-proc case
|
106 |
+
if world_size == 1:
|
107 |
+
return tensors
|
108 |
+
# Queue the reductions
|
109 |
+
reductions = []
|
110 |
+
for tensor in tensors:
|
111 |
+
reduction = dist.all_reduce(tensor, async_op=True)
|
112 |
+
reductions.append(reduction)
|
113 |
+
# Wait for reductions to finish
|
114 |
+
for reduction in reductions:
|
115 |
+
reduction.wait()
|
116 |
+
# Scale the results
|
117 |
+
if is_scale:
|
118 |
+
for tensor in tensors:
|
119 |
+
tensor.mul_(1.0 / world_size)
|
120 |
+
return tensors
|
121 |
+
|
122 |
+
|
123 |
+
def all_gather_batch(tensors):
|
124 |
+
"""
|
125 |
+
Performs all_gather operation on the provided tensors.
|
126 |
+
"""
|
127 |
+
# Queue the gathered tensors
|
128 |
+
world_size = get_world_size()
|
129 |
+
# There is no need for reduction in the single-proc case
|
130 |
+
if world_size == 1:
|
131 |
+
return tensors
|
132 |
+
tensor_list = []
|
133 |
+
output_tensor = []
|
134 |
+
for tensor in tensors:
|
135 |
+
tensor_all = [torch.ones_like(tensor) for _ in range(world_size)]
|
136 |
+
dist.all_gather(
|
137 |
+
tensor_all,
|
138 |
+
tensor,
|
139 |
+
async_op=False # performance opt
|
140 |
+
)
|
141 |
+
|
142 |
+
tensor_list.append(tensor_all)
|
143 |
+
|
144 |
+
for tensor_all in tensor_list:
|
145 |
+
output_tensor.append(torch.cat(tensor_all, dim=0))
|
146 |
+
return output_tensor
|
147 |
+
|
148 |
+
|
149 |
+
class GatherLayer(autograd.Function):
|
150 |
+
"""
|
151 |
+
Gather tensors from all workers with support for backward propagation:
|
152 |
+
This implementation does not cut the gradients as torch.distributed.all_gather does.
|
153 |
+
"""
|
154 |
+
|
155 |
+
@staticmethod
|
156 |
+
def forward(ctx, x):
|
157 |
+
output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
|
158 |
+
dist.all_gather(output, x)
|
159 |
+
return tuple(output)
|
160 |
+
|
161 |
+
@staticmethod
|
162 |
+
def backward(ctx, *grads):
|
163 |
+
all_gradients = torch.stack(grads)
|
164 |
+
dist.all_reduce(all_gradients)
|
165 |
+
return all_gradients[dist.get_rank()]
|
166 |
+
|
167 |
+
|
168 |
+
def all_gather_batch_with_grad(tensors):
|
169 |
+
"""
|
170 |
+
Performs all_gather operation on the provided tensors.
|
171 |
+
Graph remains connected for backward grad computation.
|
172 |
+
"""
|
173 |
+
# Queue the gathered tensors
|
174 |
+
world_size = get_world_size()
|
175 |
+
# There is no need for reduction in the single-proc case
|
176 |
+
if world_size == 1:
|
177 |
+
return tensors
|
178 |
+
tensor_list = []
|
179 |
+
output_tensor = []
|
180 |
+
|
181 |
+
for tensor in tensors:
|
182 |
+
tensor_all = GatherLayer.apply(tensor)
|
183 |
+
tensor_list.append(tensor_all)
|
184 |
+
|
185 |
+
for tensor_all in tensor_list:
|
186 |
+
output_tensor.append(torch.cat(tensor_all, dim=0))
|
187 |
+
return output_tensor
|
188 |
+
|
189 |
+
|
190 |
+
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
|
191 |
+
warmup_schedule = np.array([])
|
192 |
+
warmup_iters = warmup_epochs * niter_per_ep
|
193 |
+
if warmup_epochs > 0:
|
194 |
+
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
|
195 |
+
|
196 |
+
iters = np.arange(epochs * niter_per_ep - warmup_iters)
|
197 |
+
schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
|
198 |
+
|
199 |
+
schedule = np.concatenate((warmup_schedule, schedule))
|
200 |
+
assert len(schedule) == epochs * niter_per_ep
|
201 |
+
return schedule
|
202 |
+
|
203 |
+
|
204 |
+
class GaussianBlur(object):
|
205 |
+
"""Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
|
206 |
+
|
207 |
+
def __init__(self, sigma=[.1, 2.]):
|
208 |
+
self.sigma = sigma
|
209 |
+
|
210 |
+
def __call__(self, x):
|
211 |
+
sigma = random.uniform(self.sigma[0], self.sigma[1])
|
212 |
+
x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
|
213 |
+
return x
|