fffiloni committed
Commit 3bbe5bd
1 Parent(s): 5847782

Upload 6 files

xdecoder/language/build.py ADDED
@@ -0,0 +1,11 @@
+ from .registry import model_entrypoints
+ from .registry import is_model
+
+
+ def build_language_encoder(config, **kwargs):
+     model_name = config['MODEL']['TEXT']['ARCH']
+
+     if not is_model(model_name):
+         raise ValueError(f'Unknown model: {model_name}')
+
+     return model_entrypoints(model_name)(config, **kwargs)
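
For context, a minimal sketch of how this builder resolves an encoder. The config dict below is hypothetical and only shows the key consulted here; the registered entrypoints read further MODEL.TEXT keys of their own.

cfg = {'MODEL': {'TEXT': {'ARCH': 'vlpencoder'}}}  # must match a module name registered via @register_model
# lang_encoder = build_language_encoder(cfg)       # dispatches to vlpencoder.get_language_model(cfg)
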
xdecoder/language/fixvlpencoder.py ADDED
@@ -0,0 +1,35 @@
+ from importlib.metadata import requires
+ import torch
+ import torch.nn as nn
+
+ from .registry import register_model
+ from .vlpencoder import LanguageEncoder
+
+ class FixLanguageEncoder(LanguageEncoder):
+
+     def __init__(
+             self,
+             *args, **kwargs):
+         super(FixLanguageEncoder, self).__init__(*args, **kwargs)
+         self.logit_scale = nn.Parameter(torch.ones([]), requires_grad=False)
+
+     @torch.no_grad()
+     def get_text_embeddings(self, *args, **kwargs):
+         return super().get_text_embeddings(*args, **kwargs)
+
+     @torch.no_grad()
+     def get_text_token_embeddings(self, *args, **kwargs):
+         return super().get_text_token_embeddings(*args, **kwargs)
+
+     @torch.no_grad()
+     def forward_language(self, *args, **kwargs):
+         return super().forward_language(*args, **kwargs)
+
+     @torch.no_grad()
+     def forward_language_token(self, *args, **kwargs):
+         return super().forward_language_token(*args, **kwargs)
+
+
+ @register_model
+ def get_language_model(cfg, **kwargs):
+     return FixLanguageEncoder(cfg)
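
For reference (a sketch, reusing the hypothetical cfg from the build.py note above): because register_model keys entrypoints by module name, selecting the frozen text encoder is only a config change.

cfg['MODEL']['TEXT']['ARCH'] = 'fixvlpencoder'  # text tower frozen: no_grad forwards, non-trainable logit_scale
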
xdecoder/language/loss.py ADDED
@@ -0,0 +1,225 @@
+ import pickle
+ from distutils import log
+
+ import torch
+ import torch.nn.functional as F
+ import torch.distributed as dist
+
+ from einops import rearrange, repeat
+ from timm.loss import SoftTargetCrossEntropy
+
+ soft_cross_entropy = SoftTargetCrossEntropy()
+
+ def is_dist_initialized():
+     return torch.distributed.is_initialized()
+
+ def get_world_size():
+     if is_dist_initialized():
+         return torch.distributed.get_world_size()
+     return 1
+
+ def get_rank():
+     if is_dist_initialized():
+         return dist.get_rank()
+     return 0
+
+ def all_gather_grad(x):
+     if get_world_size() > 1:
+         all_x = [torch.zeros_like(x) for _ in range(get_world_size())]
+         torch.distributed.all_gather(all_x, x)
+         all_x[torch.distributed.get_rank()] = x
+         x = torch.cat(all_x, dim=0)
+     return x
+
+ def vl_multilabel_contrastive_loss(image_feat, text_feat, temperature=1):
+     """
+     Args:
+         image_feat (torch.Tensor): shape [B, L1, C] # B: batch_size, L1: 1, C: 256
+         text_feat (torch.Tensor): shape [B, L2, C] # B: batch_size, L2: number of selected nouns, C: 256
+
+     Returns:
+     """
+     # [B, L1, C], L1 = 1
+     # image_feat = F.normalize(image_feat, dim=-1)
+     # [B, L2, C]
+     # text_feat = F.normalize(text_feat, dim=-1)
+     # HACK: normalize outside
+
+     # [B, L1, L2]
+     dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')
+     # [B, L2, L1]
+     dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')
+
+     batch = image_feat.shape[0]
+     img_len = image_feat.shape[1]
+     text_len = text_feat.shape[1]
+     # [B, L1, L2]
+     pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')
+     # [B, L2, L1]
+     pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')
+
+     image_x = rearrange(image_feat, 'b l c -> (b l) c')
+     text_x = rearrange(text_feat, 'b l c -> (b l) c')
+
+     logits_per_img = image_x @ all_gather_grad(text_x).t()
+     logits_per_text = text_x @ all_gather_grad(image_x).t()
+
+     # get label globally
+     # [B, L1, B, L2, W]
+     labels_per_img = F.one_hot(
+         torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * get_rank(),
+         num_classes=get_world_size()).to(image_x.dtype)
+     labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(
+         torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')
+     # [BxL1, WxBxL2]
+     labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')
+     # [B, L2, B, L1, W]
+     labels_per_text = F.one_hot(
+         torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * get_rank(),
+         num_classes=get_world_size()).to(text_x.dtype)
+     labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(
+         torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1')
+     # [BxL2, WxBxL1]
+     labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')
+
+     logit_scale = temperature.exp().clamp(max=100)
+
+     loss_img = soft_cross_entropy(logit_scale * logits_per_img, labels_per_img)
+     loss_text = soft_cross_entropy(logit_scale * logits_per_text, labels_per_text)
+
+     loss = 0.5 * (loss_img + loss_text)
+     return loss
+
+ def vl_contrastive_loss(image_feat, text_feat, temperature=1):
+     # if image_id or text_id is None, it should be None across all GPUs
+     # image_feat = F.normalize(image_feat, dim=1)
+     # text_feat = F.normalize(text_feat, dim=1)
+     # handle normalization outside
+
+     # add the following 4 lines
+     image_feat = all_gather_grad(image_feat)
+     text_feat = all_gather_grad(text_feat)
+
+     logits = torch.matmul(image_feat, text_feat.t())
+     logit_scale = temperature.exp().clamp(max=100)
+
+     gt = torch.arange(logits.shape[0], device=logits.device)
+     loss1 = F.cross_entropy(logit_scale * logits, gt)
+     loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
+     return (loss1 + loss2) / 2 # scale it up by the number of GPUs
+
+
+ def all_gather_pickle(data, device):
+     """
+     Run all_gather on arbitrary picklable data (not necessarily tensors)
+     Args:
+         data: any picklable object
+     Returns:
+         list[data]: list of data gathered from each rank
+     """
+     world_size = get_world_size()
+     if world_size == 1:
+         return [data]
+
+     # serialized to a Tensor
+     buffer = pickle.dumps(data)
+     storage = torch.ByteStorage.from_buffer(buffer)
+     tensor = torch.ByteTensor(storage).to(device)
+
+     # obtain Tensor size of each rank
+     local_size = torch.LongTensor([tensor.numel()]).cuda()
+     size_list = [torch.LongTensor([0]).cuda() for _ in range(world_size)]
+     dist.all_gather(size_list, local_size)
+     size_list = [int(size.item()) for size in size_list]
+     max_size = max(size_list)
+
+     # receiving Tensor from all ranks
+     # we pad the tensor because torch all_gather does not support
+     # gathering tensors of different shapes
+     tensor_list = []
+     for _ in size_list:
+         tensor_list.append(torch.ByteTensor(size=(max_size,)).cuda())
+     if local_size != max_size:
+         padding = torch.ByteTensor(size=(max_size - local_size,)).cuda()
+         tensor = torch.cat((tensor, padding), dim=0)
+     dist.all_gather(tensor_list, tensor)
+
+     data_list = []
+     for size, tensor in zip(size_list, tensor_list):
+         buffer = tensor.cpu().numpy().tobytes()[:size]
+         data_list.append(pickle.loads(buffer))
+
+     return data_list
+
+ def all_gather_arbitary_tensor(tensor):
+     if get_world_size() > 1:
+         device = tensor.device
+         tensor_batch = all_gather_pickle(tensor.cpu(), device)
+         tensor_batch = [x.to(device) for x in tensor_batch]
+         tensor_batch[torch.distributed.get_rank()] = tensor
+         tensor_batch = torch.cat(tensor_batch, dim=0)
+     else:
+         tensor_batch = tensor
+     return tensor_batch
+
+ def ql_contrastive_loss(image_feat, text_feat, temperature=1):
+     # add the following 4 lines
+     image_feat = all_gather_arbitary_tensor(image_feat)
+     text_feat = all_gather_arbitary_tensor(text_feat)
+
+     logits = torch.matmul(image_feat, text_feat.t())
+     logit_scale = temperature.exp().clamp(max=100)
+
+     gt = torch.arange(logits.shape[0], device=logits.device)
+     loss1 = F.cross_entropy(logit_scale * logits, gt)
+     loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
+     return (loss1 + loss2) / 2 # scale it up by the number of GPUs
+
+ def vl_similarity(image_feat, text_feat, temperature=1):
+     # Only support single GPU for now.
+     logits = torch.matmul(image_feat, text_feat.t())
+     logits = temperature.exp().clamp(max=100) * logits
+     return logits
+
+ def ql_multi_contrastive_loss(image_feat, text_feat, text_hash, temperature=1):
+     # add the following 4 lines
+     image_feat = all_gather_arbitary_tensor(image_feat)
+     text_feat = all_gather_arbitary_tensor(text_feat)
+
+     text_hash_batch = all_gather_pickle(text_hash, text_feat.device)
+     text_hash_all = torch.cat(text_hash_batch)
+
+     text_hash_all_unique = torch.unique(text_hash_all).tolist()
+     gt = torch.zeros((image_feat.shape[0], len(text_hash_all_unique)), device=text_feat.device)
+     text_hash_all = text_hash_all.tolist()
+     text_feat_unique = torch.stack([text_feat[text_hash_all.index(txt)] for txt in text_hash_all_unique])
+
+     for idx, txt in enumerate(text_hash_all):
+         gt[idx][text_hash_all_unique.index(txt)] = 1
+
+     logits = torch.matmul(image_feat, text_feat_unique.t())
+     logits = logits*temperature.exp().clamp(max=100)
+
+     loss_img = soft_cross_entropy(logits, gt)
+     loss_text = soft_cross_entropy(logits.t(), gt.t() / gt.t().sum(-1, keepdim=True))
+
+     loss = 0.7 * loss_img + 0.3 * loss_text
+     return loss
+
+ def image_text_contrastive_loss_queue(image_feat_inp, text_feat_inp, lang_enc, training):
+     # add the following 4 lines
+     image_feat = all_gather_grad(image_feat_inp.contiguous())
+     text_feat = all_gather_grad(text_feat_inp.contiguous())
+
+     image_feat = image_feat / (image_feat.norm(dim=-1, keepdim=True) + 1e-7)
+     text_feat = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-7)
+
+     temperature = lang_enc.logit_scale
+     logits = torch.matmul(image_feat, text_feat.t())
+     logit_scale = temperature.exp().clamp(max=100)
+
+     gt = torch.arange(logits.shape[0], device=logits.device)
+     loss1 = F.cross_entropy(logit_scale * logits, gt)
+     loss2 = F.cross_entropy(logit_scale * logits.t(), gt)
+
+     return (loss1 + loss2) / 2 # scale it up by the number of GPUs
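
A single-process usage sketch for vl_contrastive_loss. Note that temperature must be a tensor (its .exp() is taken), e.g. the logit_scale parameter defined in vlpencoder.py, and that features are expected to be normalized by the caller; shapes here are illustrative.

import torch
import torch.nn.functional as F

image_feat = F.normalize(torch.randn(8, 512), dim=-1)  # 8 image embeddings, normalized outside the loss
text_feat = F.normalize(torch.randn(8, 512), dim=-1)   # 8 matching text embeddings
logit_scale = torch.nn.Parameter(torch.ones([]))        # plays the role of `temperature`

loss = vl_contrastive_loss(image_feat, text_feat, temperature=logit_scale)
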
xdecoder/language/misc.py ADDED
@@ -0,0 +1,64 @@
+ import random
+
+ import nltk
+ nltk.data.path.append('/mnt/data/nltk_data')
+ import numpy as np
+
+ from utils.constants import IMAGENET_DEFAULT_TEMPLATES
+
+
+ def get_tag(tokenized, tags):
+     if not isinstance(tags, (list, tuple)):
+         tags = [tags]
+     ret = []
+     for (word, pos) in nltk.pos_tag(tokenized):
+         for tag in tags:
+             if pos == tag:
+                 ret.append(word)
+     return ret
+
+ def get_noun_phrase(tokenized):
+     # Taken from Su Nam Kim Paper...
+     grammar = r"""
+         NBAR:
+             {<NN.*|JJ>*<NN.*>} # Nouns and Adjectives, terminated with Nouns
+
+         NP:
+             {<NBAR>}
+             {<NBAR><IN><NBAR>} # Above, connected with in/of/etc...
+     """
+     chunker = nltk.RegexpParser(grammar)
+
+     chunked = chunker.parse(nltk.pos_tag(tokenized))
+     continuous_chunk = []
+     current_chunk = []
+
+     for subtree in chunked:
+         if isinstance(subtree, nltk.Tree):
+             current_chunk.append(' '.join([token for token, pos in subtree.leaves()]))
+         elif current_chunk:
+             named_entity = ' '.join(current_chunk)
+             if named_entity not in continuous_chunk:
+                 continuous_chunk.append(named_entity)
+                 current_chunk = []
+         else:
+             continue
+
+     return continuous_chunk
+
+ def text_noun_with_prompt_all(text, phrase_prob=0.0, append_text=True):
+     tokenized = nltk.word_tokenize(text)
+
+     if random.random() >= phrase_prob:
+         nouns = get_tag(tokenized, ['NN', 'NNS', 'NNP'])
+     else:
+         nouns = get_noun_phrase(tokenized)
+
+
+     prompt_texts = [np.random.choice(IMAGENET_DEFAULT_TEMPLATES).format(noun) for noun in nouns]
+
+     if append_text:
+         prompt_texts += [text]
+         nouns += [text]
+
+     return prompt_texts, nouns
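
An illustrative call (the sentence is made up, and NLTK's tokenizer and tagger data must be available at the path appended above):

prompts, nouns = text_noun_with_prompt_all('a dog chasing a ball on the beach', phrase_prob=0.0)
# nouns   -> e.g. ['dog', 'ball', 'beach', 'a dog chasing a ball on the beach']
# prompts -> one IMAGENET_DEFAULT_TEMPLATES prompt per noun, plus the original sentence itself
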
xdecoder/language/registry.py ADDED
@@ -0,0 +1,13 @@
+ _model_entrypoints = {}
+
+ def register_model(fn):
+     module_name_split = fn.__module__.split('.')
+     model_name = module_name_split[-1]
+     _model_entrypoints[model_name] = fn
+     return fn
+
+ def model_entrypoints(model_name):
+     return _model_entrypoints[model_name]
+
+ def is_model(model_name):
+     return model_name in _model_entrypoints
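
For reference, a sketch of what the registry holds once both encoder modules have been imported; the package import path is an assumption based on the file layout in this upload.

from xdecoder.language import fixvlpencoder, vlpencoder  # importing runs the @register_model decorators
from xdecoder.language.registry import is_model, model_entrypoints

assert is_model('vlpencoder') and is_model('fixvlpencoder')
entrypoint = model_entrypoints('vlpencoder')  # -> vlpencoder.get_language_model
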
xdecoder/language/vlpencoder.py ADDED
@@ -0,0 +1,168 @@
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from timm.models.layers import trunc_normal_
+
+ from .registry import register_model
+ from ..utils import configurable
+ from .LangEncoder import build_tokenizer, build_lang_encoder
+ from utils.misc import prompt_engineering, get_prompt_templates
+
+
+ class LanguageEncoder(nn.Module):
+
+     @configurable
+     def __init__(
+         self,
+         tokenizer,
+         tokenizer_type,
+         lang_encoder,
+         lang_projection,
+         max_token_num,
+     ):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.tokenizer_type = tokenizer_type
+         self.lang_encoder = lang_encoder
+         self.lang_proj = lang_projection
+         self.max_token_num = max_token_num
+         self.logit_scale = nn.Parameter(torch.ones([]))
+
+     @classmethod
+     def from_config(cls, cfg):
+         tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
+         tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']
+         lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])
+         max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
+
+         dim_lang = cfg['MODEL']['TEXT']['WIDTH']
+         dim_projection = cfg['MODEL']['DIM_PROJ']
+         lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))
+         trunc_normal_(lang_projection, std=.02)
+
+         return {
+             "tokenizer": tokenizer,
+             "tokenizer_type": tokenizer_type,
+             "lang_encoder": lang_encoder,
+             "lang_projection": lang_projection,
+             "max_token_num": max_token_num,
+         }
+
+     def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True):
+         if not is_eval:
+             if prompt:
+                 # randomly sample one template
+                 arbitary_concepts = [
+                     prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \
+                     for label in range(len(class_names))
+                 ]
+                 if add_bgd:
+                     arbitary_concepts.append("A background in coco.")
+             else:
+                 arbitary_concepts = class_names
+
+             input_ids = []
+             attention_masks = []
+             for txt in arbitary_concepts:
+                 tokens = self.tokenizer(
+                     txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
+                 )
+                 tokens['input_ids'].squeeze_()
+                 tokens['attention_mask'].squeeze_()
+
+                 input_ids.append(tokens['input_ids'])
+                 attention_masks.append(tokens['attention_mask'])
+
+             arbitary_tokens = torch.stack(input_ids)
+             arbitary_attention_masks = torch.stack(attention_masks)
+
+             text_emb = self.forward_language((arbitary_tokens.cuda(), arbitary_attention_masks.cuda()), norm=norm)
+             setattr(self, '{}_text_embeddings'.format(name), text_emb)
+         else:
+             with torch.no_grad():
+                 def extract_mean_emb(txts):
+                     tokens = self.tokenizer(
+                         txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
+                     )
+                     clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm)
+                     clss_embedding = clss_embedding.mean(dim=0)
+                     clss_embedding /= clss_embedding.norm()
+                     return clss_embedding
+
+                 templates = get_prompt_templates()
+                 clss_embeddings = []
+                 if prompt:
+                     for clss in class_names:
+                         txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates]
+                         clss_embeddings.append(extract_mean_emb(txts))
+                 else:
+                     clss_embeddings.append(extract_mean_emb(class_names))
+
+                 if add_bgd:
+                     txts = ["A background in coco."]
+                     clss_embeddings.append(extract_mean_emb(txts))
+
+                 text_emb = torch.stack(clss_embeddings, dim=0)
+                 setattr(self, '{}_text_embeddings'.format(name), text_emb)
+
+     def get_text_token_embeddings(self, txts, name='default', token=False, norm=False):
+         if not token:
+             tokens = self.tokenizer(
+                 txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
+             )
+             tokens = {key: value.cuda() for key, value in tokens.items()}
+         else:
+             tokens = txts
+         token_emb, class_emb = self.forward_language_token((tokens['input_ids'], tokens['attention_mask']), norm=norm)
+         ret = {"tokens": tokens,
+                "token_emb": token_emb,
+                "class_emb": class_emb,}
+         setattr(self, '{}_token_embeddings'.format(name), ret)
+         return ret
+
+     def forward_language(self, texts, norm=True):
+         x = self.lang_encoder(*texts)
+         x = x['last_hidden_state']
+
+         if self.tokenizer_type == 'clip':
+             x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]
+         else:
+             x = x[:, 0]
+
+         x = x @ self.lang_proj
+         if norm:
+             x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)
+         return x
+
+     def forward_language_token(self, texts, norm=False):
+         x = self.lang_encoder(*texts)
+         token_x = x['last_hidden_state']
+
+         if self.tokenizer_type == 'clip':
+             class_x = token_x[torch.arange(token_x.size(0)), texts[0].argmax(dim=-1)]
+         else:
+             class_x = token_x[:, 0]
+
+         class_x = class_x @ self.lang_proj
+         token_x = token_x @ self.lang_proj
+
+         if norm:
+             class_x = class_x / (class_x.norm(dim=-1, keepdim=True) + 1e-7)
+             token_x = token_x / (token_x.norm(dim=-1, keepdim=True) + 1e-7)
+
+         return token_x, class_x
+
+     def compute_similarity(self, v_emb, name='default', fake=False):
+         if fake:
+             return None
+         v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
+         t_emb = getattr(self, '{}_text_embeddings'.format(name))
+         output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)
+         return output
+
+
+ @register_model
+ def get_language_model(cfg, **kwargs):
+     return LanguageEncoder(cfg)
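
A rough sketch of the intended call pattern, assuming a CUDA device (tokenized inputs are moved with .cuda()) and a config carrying the MODEL.TEXT keys read in from_config; shapes follow the code above.

# lang_encoder = build_language_encoder(cfg)                        # ARCH='vlpencoder' -> LanguageEncoder
# lang_encoder.get_text_embeddings(['cat', 'dog'], is_eval=True)    # caches a [num_classes, DIM_PROJ] table under 'default'
# sim = lang_encoder.compute_similarity(v_emb)                      # v_emb: [B, Q, DIM_PROJ] -> sim: [B, Q, num_classes]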