|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
|
import torch |
|
from safetensors.torch import load_file |
|
from transformers.models.t5.configuration_t5 import T5Config |
|
from transformers.models.t5.modeling_t5 import T5Stack |
|
|
|
|
|
class AniMemoryT5(torch.nn.Module):
    """T5 encoder stack paired with a separately stored token-embedding table.

    Token ids are embedded by ``embed_tokens_encoder`` and the resulting
    embeddings are passed to the ``T5Stack`` through ``inputs_embeds``.

    Attributes:
        encoder: The ``T5Stack`` that processes the embeddings.
        embed_tokens_encoder: Embedding table applied to the input token ids.
        num_head: Number of attention heads, taken from ``config.num_heads``;
            used by ``make_attn_mask``.
        dtype: Dtype last recorded for the module (updated by ``to``).
        device: Device last recorded for the module (updated by ``to``).
    """

    def __init__(self, config: "T5Config", embed_tokens=None):
        """Build the encoder stack and a freshly initialised embedding table.

        Args:
            config: T5 configuration forwarded to ``T5Stack``.
            embed_tokens: Optional embedding module forwarded to ``T5Stack``.
        """
        super().__init__()
        self.encoder = T5Stack(config, embed_tokens)
        # The table size is hard-coded; it has to match the tensors stored in
        # the embedding checkpoint that ``from_pretrained`` loads into it.
        self.embed_tokens_encoder = torch.nn.Embedding(250002, 4096, padding_idx=1)
        # ``make_attn_mask`` replicates the mask once per attention head.
        self.num_head = config.num_heads
        # Seed the bookkeeping attributes so they always exist on an instance,
        # whether or not ``to`` / ``from_pretrained`` has been called yet.
        table_weight = self.embed_tokens_encoder.weight
        self.dtype = table_weight.dtype
        self.device = table_weight.device

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path,
        subfolder="",
        embed_tokens=None,
        emb_name="weights.safetensors",
        torch_dtype=torch.float16,
    ):
        """Load the encoder and the embedding table from a checkpoint directory.

        Args:
            pretrained_model_name_or_path: Location of the checkpoint. It is
                joined with ``subfolder`` and ``emb_name`` through
                ``os.path.join`` to locate the embedding file, so it has to be
                a filesystem path for that step.
            subfolder: Sub-directory holding the config and weights.
            embed_tokens: Forwarded to the constructor only; the stack built
                there is replaced by the one loaded from the checkpoint.
            emb_name: File name of the safetensors file holding the embedding
                table's state dict.
            torch_dtype: Dtype the encoder and the embedding table are cast
                to. ``None`` leaves the loaded dtypes untouched.

        Returns:
            The populated ``AniMemoryT5`` instance.
        """
        config = T5Stack.config_class.from_pretrained(
            pretrained_model_name_or_path, subfolder=subfolder
        )
        # NOTE: the constructor allocates a randomly initialised stack that is
        # discarded on the next statement; kept for compatibility with __init__.
        model = cls(config=config, embed_tokens=embed_tokens)
        model.encoder = T5Stack.from_pretrained(
            pretrained_model_name_or_path, subfolder=subfolder
        )

        embed_state_dict = load_file(
            os.path.join(pretrained_model_name_or_path, subfolder, emb_name)
        )
        model.embed_tokens_encoder.load_state_dict(embed_state_dict)

        if torch_dtype is not None:
            model.encoder.to(torch_dtype)
            model.embed_tokens_encoder.to(torch_dtype)
            # Record the dtype on the instance rather than on the class, so
            # loading one model does not alter the state of other instances.
            model.dtype = torch_dtype
        return model

    def to(self, *args, **kwargs):
        """Move/cast the module and remember the requested dtype and device.

        Accepts the same arguments as ``torch.nn.Module.to``.

        Returns:
            ``self``, to allow call chaining.
        """
        # Same argument parser ``torch.nn.Module.to`` uses internally.
        device, dtype, _non_blocking, _memory_format = torch._C._nn._parse_to(
            *args, **kwargs
        )
        super().to(*args, **kwargs)
        # Only overwrite what the call actually specified; a value that was
        # not given keeps whatever was recorded before.
        if dtype is not None:
            self.dtype = dtype
        if device is not None:
            self.device = device
        return self

    def make_attn_mask(self, attn_mask):
        """Expand a 2-D mask into one square mask per attention head.

        Args:
            attn_mask: Tensor whose first two dimensions are read as
                ``(batch, seq_len)``.

        Returns:
            Float tensor of shape ``(batch * num_head, seq_len, seq_len)`` in
            which every row of each square matrix is a copy of the
            corresponding input mask row.
        """
        seq_len = attn_mask.shape[1]
        row = attn_mask.unsqueeze(1).float()  # (batch, 1, seq_len)
        square = row.repeat([1, seq_len, 1])  # (batch, seq_len, seq_len)
        per_head = square.unsqueeze(1).repeat([1, self.num_head, 1, 1])
        # Fold the head dimension into the batch dimension.
        return per_head.view([-1, seq_len, seq_len])

    def forward(self, text, attention_mask):
        """Encode token ids and return the normalised penultimate hidden state.

        Args:
            text: Token ids looked up in ``embed_tokens_encoder``.
            attention_mask: Mask forwarded unchanged to the encoder.

        Returns:
            A 2-tuple holding the same tensor twice: the second-to-last entry
            of the encoder's ``hidden_states`` after the encoder's
            ``final_layer_norm`` has been applied to it.
        """
        embeddings = self.embed_tokens_encoder(text)
        encoder_outputs = self.encoder(
            inputs_embeds=embeddings,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden_states = encoder_outputs.hidden_states[-2]
        hidden_states = self.encoder.final_layer_norm(hidden_states)
        # NOTE(review): both tuple slots are the same tensor — presumably the
        # caller expects a pair; confirm against the call site.
        return hidden_states, hidden_states
|
|