Spaces:
Runtime error
Runtime error
'''
LinCIR
Copyright (c) 2023-present NAVER Corp.
CC BY-NC-4.0 (https://creativecommons.org/licenses/by-nc/4.0/)
'''
import torch | |
from clip.model import CLIP | |
from transformers import CLIPTextModelWithProjection | |
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Build the additive causal attention mask: each position may attend only to
    itself and to earlier positions.

    Adapted from
    https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/models/clip/modeling_clip.py#L679-L693

    Args:
        input_ids_shape: Two-element shape ``(bsz, tgt_len)`` of the token-id tensor.
        dtype: Floating-point dtype of the returned mask (passed to ``torch.finfo``).
        device: Device on which the mask is created.
        past_key_values_length: Number of extra key columns prepended to the mask;
            these columns are all zeros, i.e. never masked.

    Returns:
        A ``(bsz, 1, tgt_len, tgt_len + past_key_values_length)`` view holding ``0``
        where key ``j <= i`` for query ``i`` and ``torch.finfo(dtype).min`` elsewhere.
    """
    bsz, tgt_len = input_ids_shape

    # Allocate directly in the requested dtype. Building the tensor in the default
    # float dtype first (and casting afterwards) cannot represent
    # ``finfo(dtype).min`` for wider dtypes such as float64 and costs an extra copy.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype, device=device)

    # Zero out the lower triangle including the diagonal (key index <= query index).
    positions = torch.arange(tgt_len, device=device)
    mask.masked_fill_(positions <= positions.view(tgt_len, 1), 0)

    if past_key_values_length > 0:
        visible_past = torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device)
        mask = torch.cat([visible_past, mask], dim=-1)

    # ``expand`` broadcasts over the batch without copying the mask.
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
def encode_with_pseudo_tokens_HF(clip_model: CLIPTextModelWithProjection, text: torch.Tensor, pseudo_tokens: torch.Tensor,
                                 num_tokens=1, return_last_states=False) -> torch.Tensor:
    """
    Encode tokenized text with a CLIP text model, replacing the embedding of every
    placeholder token by the supplied pseudo-token embedding.

    Args:
        clip_model: HuggingFace CLIP text model; its ``text_model`` sub-modules
            (embeddings, encoder, final layer norm) are called directly.
        text: Token ids, two-dimensional ``(batch_size, n_ctx)``.
        pseudo_tokens: Replacement embeddings; presumably ``(batch_size, d_model)``,
            since they are broadcast along the sequence axis -- confirm with callers.
        num_tokens: Currently unused; kept for interface compatibility.
        return_last_states: When true, also return the layer-normed hidden states
            of every position.

    Returns:
        The (projected, when the model has ``text_projection``) feature taken at the
        position of the largest token id in each sequence. When
        ``return_last_states`` is true, a tuple ``(features, last_hidden_states)``.
    """
    # Token id whose embedding gets swapped for the pseudo token.
    # NOTE(review): presumably the tokenizer id of the placeholder word -- confirm.
    pseudo_token_id = 259

    text_model = clip_model.text_model
    embeddings = text_model.embeddings
    model_dtype = clip_model.dtype

    x = embeddings.token_embedding(text).type(model_dtype)  # [batch_size, n_ctx, d_model]
    x = torch.where(text.unsqueeze(-1) == pseudo_token_id,
                    pseudo_tokens.unsqueeze(1).type(model_dtype),
                    x)

    # Use only as many position ids as there are tokens, so inputs shorter than the
    # model's maximum context length are supported; full-length input is unchanged.
    position_ids = embeddings.position_ids[:, :text.shape[-1]]
    x = x + embeddings.position_embedding(position_ids)

    _causal_attention_mask = _make_causal_mask(text.shape, x.dtype, device=x.device)
    encoder_outputs = text_model.encoder(inputs_embeds=x,
                                         attention_mask=None,
                                         causal_attention_mask=_causal_attention_mask,
                                         output_attentions=False,
                                         output_hidden_states=False,
                                         return_dict=False)
    # With return_dict=False the encoder returns a tuple; item 0 is the hidden states.
    x_last = text_model.final_layer_norm(encoder_outputs[0])

    # Pool at the position holding the largest token id in each sequence
    # (presumably the end-of-text token -- confirm against the tokenizer).
    batch_index = torch.arange(x_last.shape[0], device=x_last.device)
    pooled_index = text.to(dtype=torch.int, device=x_last.device).argmax(dim=-1)
    x = x_last[batch_index, pooled_index]

    if hasattr(clip_model, 'text_projection'):
        x = clip_model.text_projection(x)

    if return_last_states:
        return x, x_last
    return x