# motionfix-demo / text_encoder.py
import os
from typing import List

from torch import nn


class ClipTextEncoder(nn.Module):
def __init__(
self,
        modelpath: str = 'openai/clip-vit-large-patch14',  # or 'openai/clip-vit-base-patch32'
finetune: bool = False,
**kwargs
) -> None:
super().__init__()
from transformers import logging
from transformers import AutoModel, AutoTokenizer
logging.set_verbosity_error()
# Tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"
self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
self.text_model = AutoModel.from_pretrained(modelpath)
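        # AutoModel on a CLIP checkpoint loads the full CLIPModel; only its
        # text tower (self.text_model.text_model) is used in forward() below.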
# Don't train the model
if not finetune:
            self.text_model.eval()  # eval() recursively switches all submodules out of training mode
for p in self.text_model.parameters():
p.requires_grad = False
# Then configure the model
self.max_length = self.tokenizer.model_max_length
self.text_encoded_dim = self.text_model.config.text_config.hidden_size

    def forward(self, texts: List[str]):
# get prompt text embeddings
text_inputs = self.tokenizer(
texts,
padding="max_length",
truncation=True,
max_length=self.max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(self.text_model.device)
txt_att_mask = text_inputs.attention_mask.to(self.text_model.device)
        # Truncate to the maximum sequence length CLIP can handle
        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            text_input_ids = text_input_ids[:, :self.tokenizer.model_max_length]
        # Use the pooled output if the latent dim is two-dimensional
        # pooled = 0 if self.latent_dim[0] == 1 else 1  # (bs, seq_len, text_encoded_dim) -> (bs, text_encoded_dim)
        # Text-encoder forward; for pooled/projected CLIP features, get_text_features would be used instead.
        # Output shape: (batch_size, seq_length, text_encoded_dim)
        text_embeddings = self.text_model.text_model(
            text_input_ids,
            # attention_mask=txt_att_mask  # mask is returned to the caller, not applied here
        ).last_hidden_state
return text_embeddings, txt_att_mask.bool()
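

# Minimal usage sketch (an assumption, not part of the original pipeline):
# the prompts below are made-up examples, and running this downloads the
# CLIP weights from the Hugging Face Hub.
if __name__ == "__main__":
    encoder = ClipTextEncoder()
    embs, mask = encoder(["a person waves with the right hand", "walk faster"])
    # Expected shapes: (2, max_length, text_encoded_dim) and (2, max_length)
    print(embs.shape, mask.shape)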