fffiloni's picture
Upload 244 files
b3f324b verified
raw
history blame contribute delete
No virus
1.61 kB
import torch
from torch import nn
from transformers import T5EncoderModel, CLIPModel, CLIPProcessor
from opensora.utils.utils import get_precision
class T5Wrapper(nn.Module):
    """Wrap a pretrained ``T5EncoderModel`` as an inference-only text encoder.

    The encoder is loaded from ``args.text_encoder_name`` and switched to eval
    mode; ``forward`` returns its last hidden state detached from autograd.
    """

    def __init__(self, args):
        """Load the T5 encoder described by ``args``.

        Args:
            args: Configuration object. ``args.text_encoder_name`` is handed to
                ``T5EncoderModel.from_pretrained``; the whole object is handed
                to ``get_precision`` to pick the ``torch_dtype`` (helper is
                defined outside this file -- presumably returns a
                ``torch.dtype``).
        """
        super().__init__()
        self.model_name = args.text_encoder_name
        dtype = get_precision(args)
        t5_model_kwargs = {
            'cache_dir': './cache_dir',
            'low_cpu_mem_usage': True,
            'torch_dtype': dtype,
        }
        self.text_enc = T5EncoderModel.from_pretrained(
            self.model_name, **t5_model_kwargs
        ).eval()

    def forward(self, input_ids, attention_mask):
        """Encode tokenized text and return detached embeddings.

        Args:
            input_ids: Token ids forwarded to the encoder as ``input_ids``.
            attention_mask: Mask forwarded to the encoder as ``attention_mask``.

        Returns:
            The encoder's ``'last_hidden_state'`` output, detached so that no
            gradient can flow back into the encoder.
        """
        # The result is detached below, so recording an autograd graph through
        # the encoder would only hold activations that are never used.
        with torch.no_grad():
            encoder_out = self.text_enc(
                input_ids=input_ids, attention_mask=attention_mask
            )
            text_encoder_embs = encoder_out['last_hidden_state']
        return text_encoder_embs.detach()
class CLIPWrapper(nn.Module):
    """Wrap a pretrained ``CLIPModel`` as an inference-only text encoder.

    The model is loaded from ``args.text_encoder_name`` and switched to eval
    mode; ``forward`` returns the result of ``get_text_features`` detached
    from autograd.
    """

    def __init__(self, args):
        """Load the CLIP model described by ``args``.

        Args:
            args: Configuration object. ``args.text_encoder_name`` is handed to
                ``CLIPModel.from_pretrained``; the whole object is handed to
                ``get_precision`` to pick the ``torch_dtype`` (helper is
                defined outside this file -- presumably returns a
                ``torch.dtype``).
        """
        super().__init__()
        self.model_name = args.text_encoder_name
        dtype = get_precision(args)
        model_kwargs = {
            'cache_dir': './cache_dir',
            'low_cpu_mem_usage': True,
            'torch_dtype': dtype,
        }
        self.text_enc = CLIPModel.from_pretrained(
            self.model_name, **model_kwargs
        ).eval()

    def forward(self, input_ids, attention_mask):
        """Encode tokenized text and return detached CLIP text features.

        Args:
            input_ids: Token ids forwarded to ``get_text_features``.
            attention_mask: Mask forwarded to ``get_text_features``.

        Returns:
            The value returned by ``CLIPModel.get_text_features``, detached so
            that no gradient can flow back into the model.
        """
        # The result is detached below, so recording an autograd graph through
        # the model would only hold activations that are never used.
        with torch.no_grad():
            text_encoder_embs = self.text_enc.get_text_features(
                input_ids=input_ids, attention_mask=attention_mask
            )
        return text_encoder_embs.detach()
# Registry mapping each supported ``args.text_encoder_name`` value to the
# wrapper class that loads it; ``get_text_enc`` looks names up here.
text_encoder = {
'DeepFloyd/t5-v1_1-xxl': T5Wrapper,
'openai/clip-vit-large-patch14': CLIPWrapper
}
def get_text_enc(args):
    """Build the text-encoder wrapper registered for ``args.text_encoder_name``.

    The original docstring marks this function as deprecated.

    Args:
        args: Configuration object. ``args.text_encoder_name`` selects the
            wrapper class from the module-level ``text_encoder`` registry, and
            the whole object is passed to that class's constructor.

    Returns:
        An instance of the wrapper class registered under the given name.

    Raises:
        ValueError: If ``args.text_encoder_name`` is not in the registry.
    """
    name = args.text_encoder_name
    text_enc = text_encoder.get(name)
    if text_enc is None:
        # Raise explicitly rather than ``assert``: assertions are stripped
        # under ``python -O``, which would turn an unknown name into an opaque
        # "'NoneType' object is not callable" error.
        supported = ', '.join(sorted(text_encoder))
        raise ValueError(
            f"Unsupported text encoder {name!r}; expected one of: {supported}"
        )
    return text_enc(args)