# Demo/xdecoder/language/vlpencoder.py
import torch
from torch import nn
from torch.nn import functional as F
from timm.models.layers import trunc_normal_
from .registry import register_model
from ..utils import configurable
from .LangEncoder import build_tokenizer, build_lang_encoder
from utils.misc import prompt_engineering, get_prompt_templates
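# Language branch: wraps a tokenizer and a language encoder built by
# build_lang_encoder, and projects the encoder outputs into the joint embedding
# space so they can be compared with visual query embeddings in compute_similarity.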
class LanguageEncoder(nn.Module):
@configurable
def __init__(
self,
tokenizer,
tokenizer_type,
lang_encoder,
lang_projection,
max_token_num,
):
super().__init__()
self.tokenizer = tokenizer
self.tokenizer_type = tokenizer_type
self.lang_encoder = lang_encoder
self.lang_proj = lang_projection
self.max_token_num = max_token_num
self.logit_scale = nn.Parameter(torch.ones([]))
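    # Assembles the components from the config for the @configurable decorator on
    # __init__. Keys read here: MODEL.TEXT.TOKENIZER, MODEL.TEXT.CONTEXT_LENGTH,
    # MODEL.TEXT.WIDTH, MODEL.DIM_PROJ and VERBOSE. The projection matrix maps the
    # encoder width to the joint embedding dimension and is trunc-normal initialized.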
@classmethod
def from_config(cls, cfg):
tokenizer = build_tokenizer(cfg['MODEL']['TEXT'])
tokenizer_type = cfg['MODEL']['TEXT']['TOKENIZER']
lang_encoder = build_lang_encoder(cfg['MODEL']['TEXT'], tokenizer, cfg['VERBOSE'])
max_token_num = cfg['MODEL']['TEXT']['CONTEXT_LENGTH']
dim_lang = cfg['MODEL']['TEXT']['WIDTH']
dim_projection = cfg['MODEL']['DIM_PROJ']
lang_projection = nn.Parameter(torch.empty(dim_lang, dim_projection))
trunc_normal_(lang_projection, std=.02)
return {
"tokenizer": tokenizer,
"tokenizer_type": tokenizer_type,
"lang_encoder": lang_encoder,
"lang_projection": lang_projection,
"max_token_num": max_token_num,
}
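    # Encodes class names into text embeddings and caches them on the module as
    # `<name>_text_embeddings`. In training mode (is_eval=False) one prompt template
    # is sampled per class via prompt_engineering; in eval mode the embedding is
    # averaged over all templates from get_prompt_templates() and re-normalized.
    # add_bgd appends an extra "background" entry; prompt=False encodes raw names.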
def get_text_embeddings(self, class_names, name='default', is_eval=False, add_bgd=False, prompt=True, norm=True):
if not is_eval:
if prompt:
# randomly sample one template
                arbitrary_concepts = [
                    prompt_engineering(class_names[label].replace('-other','').replace('-merged','').replace('-stuff',''), topk=10000, suffix='.') \
                    for label in range(len(class_names))
                ]
                if add_bgd:
                    arbitrary_concepts.append("A background in coco.")
            else:
                arbitrary_concepts = class_names
input_ids = []
attention_masks = []
            for txt in arbitrary_concepts:
tokens = self.tokenizer(
txt, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
)
tokens['input_ids'].squeeze_()
tokens['attention_mask'].squeeze_()
input_ids.append(tokens['input_ids'])
attention_masks.append(tokens['attention_mask'])
            arbitrary_tokens = torch.stack(input_ids)
            arbitrary_attention_masks = torch.stack(attention_masks)
            text_emb = self.forward_language((arbitrary_tokens.cuda(), arbitrary_attention_masks.cuda()), norm=norm)
setattr(self, '{}_text_embeddings'.format(name), text_emb)
else:
with torch.no_grad():
def extract_mean_emb(txts):
tokens = self.tokenizer(
txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
)
clss_embedding = self.forward_language((tokens['input_ids'].cuda(), tokens['attention_mask'].cuda()), norm=norm)
clss_embedding = clss_embedding.mean(dim=0)
clss_embedding /= clss_embedding.norm()
return clss_embedding
templates = get_prompt_templates()
clss_embeddings = []
if prompt:
for clss in class_names:
txts = [template.format(clss.replace('-other','').replace('-merged','').replace('-stuff','')) for template in templates]
clss_embeddings.append(extract_mean_emb(txts))
else:
clss_embeddings.append(extract_mean_emb(class_names))
if add_bgd:
txts = ["A background in coco."]
clss_embeddings.append(extract_mean_emb(txts))
text_emb = torch.stack(clss_embeddings, dim=0)
setattr(self, '{}_text_embeddings'.format(name), text_emb)
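    # Returns per-token and pooled ("class") embeddings for `txts`. With token=True,
    # `txts` is assumed to already be a tokenizer output dict with input_ids and
    # attention_mask; the result is cached as `<name>_token_embeddings`.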
def get_text_token_embeddings(self, txts, name='default', token=False, norm=False):
if not token:
tokens = self.tokenizer(
txts, padding='max_length', truncation=True, max_length=self.max_token_num, return_tensors='pt'
)
tokens = {key: value.cuda() for key, value in tokens.items()}
else:
tokens = txts
token_emb, class_emb = self.forward_language_token((tokens['input_ids'], tokens['attention_mask']), norm=norm)
ret = {"tokens": tokens,
"token_emb": token_emb,
"class_emb": class_emb,}
setattr(self, '{}_token_embeddings'.format(name), ret)
return ret
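    # Pooled sentence embedding: with the CLIP tokenizer the feature at the EOT
    # position (argmax of the input ids) is used, otherwise the first ([CLS]) token;
    # the pooled feature is projected by lang_proj and optionally L2-normalized.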
def forward_language(self, texts, norm=True):
x = self.lang_encoder(*texts)
x = x['last_hidden_state']
if self.tokenizer_type == 'clip':
x = x[torch.arange(x.size(0)), texts[0].argmax(dim=-1)]
else:
x = x[:, 0]
x = x @ self.lang_proj
if norm:
x = x / (x.norm(dim=-1, keepdim=True) + 1e-7)
return x
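    # Same pooling as forward_language, but also returns the projected per-token
    # embeddings alongside the pooled class embedding.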
def forward_language_token(self, texts, norm=False):
x = self.lang_encoder(*texts)
token_x = x['last_hidden_state']
if self.tokenizer_type == 'clip':
class_x = token_x[torch.arange(token_x.size(0)), texts[0].argmax(dim=-1)]
else:
class_x = token_x[:, 0]
class_x = class_x @ self.lang_proj
token_x = token_x @ self.lang_proj
if norm:
class_x = class_x / (class_x.norm(dim=-1, keepdim=True) + 1e-7)
token_x = token_x / (token_x.norm(dim=-1, keepdim=True) + 1e-7)
return token_x, class_x
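    # Scaled similarity between visual embeddings and the text embeddings cached
    # under `name`: v_emb is L2-normalized here, multiplied by exp(logit_scale),
    # and matched against `<name>_text_embeddings`. fake=True skips the computation.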
def compute_similarity(self, v_emb, name='default', fake=False):
if fake:
return None
v_emb = v_emb / (v_emb.norm(dim=-1, keepdim=True) + 1e-7)
t_emb = getattr(self, '{}_text_embeddings'.format(name))
output = self.logit_scale.exp() * v_emb @ t_emb.unsqueeze(0).transpose(1, 2)
return output
@register_model
def get_language_model(cfg, **kwargs):
return LanguageEncoder(cfg)
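# Hedged usage sketch (illustrative only, not part of the module): the encoder is
# normally built through the registry via get_language_model(cfg); text embeddings
# are then cached by name and compared against visual query embeddings. With cached
# text embeddings of shape (num_classes, dim_proj) and v_emb of shape
# (batch, queries, dim_proj), compute_similarity returns (batch, queries, num_classes):
#
#   lang_encoder = get_language_model(cfg)
#   lang_encoder.get_text_embeddings(class_names, name='default', is_eval=True)
#   logits = lang_encoder.compute_similarity(v_emb, name='default')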