import os
from typing import List, Union

import torch
from torch import Tensor, nn

class ClipTextEncoder(nn.Module):
    def __init__(
        self,
        modelpath: str = "openai/clip-vit-large-patch14",  # alternative: "openai/clip-vit-base-patch32"
        finetune: bool = False,
        **kwargs,
    ) -> None:
        super().__init__()
        from transformers import logging
        from transformers import AutoModel, AutoTokenizer

        logging.set_verbosity_error()

        # Tokenizer
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
        self.text_model = AutoModel.from_pretrained(modelpath)

        # Don't train the model: eval() recursively switches off training mode
        # (assigning self.text_model.training = False would not reach submodules),
        # and freezing the parameters keeps them out of the optimizer's reach.
        if not finetune:
            self.text_model.eval()
            for p in self.text_model.parameters():
                p.requires_grad = False

        # Then configure the model
        self.max_length = self.tokenizer.model_max_length
        self.text_encoded_dim = self.text_model.config.text_config.hidden_size
    def forward(self, texts: List[str]):
        # Get prompt text embeddings
        text_inputs = self.tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_model.device)
        txt_att_mask = text_inputs.attention_mask.to(self.text_model.device)

        # Truncate to the maximum sequence length CLIP can handle
        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]

        # Use pooled output if latent dim is two-dimensional
        # pooled = 0 if self.latent_dim[0] == 1 else 1  # (bs, seq_len, text_encoded_dim) -> (bs, text_encoded_dim)
        # Text encoder forward; for CLIP's pooled, projected features use
        # get_text_features instead. Output: (batch_size, seq_length, text_encoded_dim)
        text_embeddings = self.text_model.text_model(
            text_input_ids,
            # attention_mask=txt_att_mask
        ).last_hidden_state

        return text_embeddings, txt_att_mask.bool()
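

# Minimal usage sketch (an illustration, not part of the original module):
# encode a small batch of prompts and inspect the returned shapes. Assumes
# torch and transformers are installed and the CLIP checkpoint is reachable
# on the Hugging Face Hub; the example prompts are made up.
if __name__ == "__main__":
    encoder = ClipTextEncoder()  # defaults to openai/clip-vit-large-patch14
    with torch.no_grad():
        embeddings, mask = encoder(["a person walks forward", "a person jumps"])
    # clip-vit-large-patch14: context length 77, text hidden size 768
    print(embeddings.shape)  # torch.Size([2, 77, 768])
    print(mask.shape)        # torch.Size([2, 77])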