# Author: Yixuan Li
# Commit: 4853fdc ("first commit")
import torch
import torch.nn as nn
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from Qformer import BertConfig, BertLMHeadModel
# Device type used for the autocast context in QformerBridgeNet.forward:
# "npu" when the torch_npu plugin can be imported, otherwise "cuda".
# transfer_to_npu is imported but never referenced here; presumably it is
# imported for its import-time side effects (confirm against torch_npu docs).
try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
except ModuleNotFoundError:
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "npu"
def generate_length_mask(lens, max_length=None):
    """Build a 0/1 padding mask from per-sequence lengths.

    Args:
        lens: Sequence lengths (expected to be integer counts); anything
            ``torch.as_tensor`` accepts, e.g. a list of ints or a 1-D tensor
            of shape ``(N,)``.
        max_length: Number of mask columns. Defaults to the largest length,
            or 0 when ``lens`` is empty.

    Returns:
        An ``int32`` tensor of shape ``(N, max_length)`` on ``lens``'s device
        where ``mask[i, j] == 1`` iff ``j < lens[i]``.
    """
    lens = torch.as_tensor(lens)
    if max_length is None:
        # Reduce with a tensor op instead of the builtin max(), which iterates
        # the tensor element by element and raises on an empty input.
        max_length = int(lens.max()) if lens.numel() > 0 else 0
    # Broadcast a (1, max_length) row of positions against an (N, 1) column of
    # lengths. The positions are created directly on lens's device, so no
    # CPU-side index matrix or host->device copy is needed.
    positions = torch.arange(max_length, device=lens.device).unsqueeze(0)
    return (positions < lens.view(-1, 1)).int()
class QformerBridgeNet(torch.nn.Module):
    """Bridge audio features to a fixed set of query vectors via a Q-Former.

    A BERT-style Q-Former (built by ``init_Qformer``, 2 layers by default)
    cross-attends a learned set of ``num_query_token`` query embeddings to the
    incoming audio features. ``hiddin_projection`` then maps the resulting
    hidden states to ``hiddin_size`` features.

    Args:
        Qformer_model_name: Name/path passed to ``BertConfig.from_pretrained``
            to obtain the Q-Former configuration.
        num_query_token: Number of learned query tokens.
        hiddin_size: Output feature size of ``hiddin_projection``.
        speech_width: Feature size of the audio embeddings, used as the
            cross-attention ``encoder_width``.
        freeze_QFormer: If True, the Q-Former parameters and query tokens get
            ``requires_grad=False``, the Q-Former is put in eval mode, and
            ``forward`` runs without autograd.
        load_from_pretrained: Optional checkpoint path whose state dict
            (minus ``projection.*`` keys) is loaded into this module.
    """

    def __init__(self, Qformer_model_name: str = "bert-base-uncased", num_query_token: int = 32,
                 hiddin_size: int = 1024, speech_width: int = 1024, freeze_QFormer: bool = True,
                 load_from_pretrained: str = None):
        super().__init__()
        self.Qformer_model_name = Qformer_model_name
        self.audio_Qformer, self.audio_query_tokens, encoder_config = self.init_Qformer(
            num_query_token=num_query_token, speech_width=speech_width
        )
        # ``hidden`` only ever feeds query embeddings (no input_ids), so the LM
        # head and the token/position embeddings are dropped. The per-layer
        # ``output``/``intermediate`` sub-modules are dropped too; presumably
        # the Qformer routes query tokens through separate sub-layers (confirm
        # against Qformer.py).
        self.audio_Qformer.cls = None
        self.audio_Qformer.bert.embeddings.word_embeddings = None
        self.audio_Qformer.bert.embeddings.position_embeddings = None
        for layer in self.audio_Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        self.freeze_QFormer = freeze_QFormer
        if freeze_QFormer:
            for param in self.audio_Qformer.parameters():
                param.requires_grad = False
            # NOTE(review): a later ``.train()`` on a parent module switches the
            # Q-Former back to train mode; override ``train`` if it must stay
            # in eval mode.
            self.audio_Qformer.eval()
            self.audio_query_tokens.requires_grad = False

        self.hiddin_projection = torch.nn.Linear(encoder_config.hidden_size, hiddin_size)

        if load_from_pretrained:
            # map_location="cpu" lets checkpoints saved on a GPU load on any
            # host. load_state_dict copies the values into the existing
            # parameters.
            # NOTE(review): torch.load unpickles the file; only load trusted
            # checkpoints.
            state_dict = torch.load(load_from_pretrained, map_location="cpu")
            # This module has no ``projection`` attribute. These keys presumably
            # come from the model the checkpoint was saved from, and must be
            # removed before the strict load.
            del_key = {"projection.weight", "projection.bias"}
            del_state_dict = {k: v for k, v in state_dict.items() if k not in del_key}
            self.load_state_dict(del_state_dict)
            print("Load adaptor_model_pt from", load_from_pretrained)

    def init_Qformer(self, num_query_token, speech_width, num_hidden_layers=2, cross_attention_freq=2):
        """Create the Q-Former model and its learned query tokens.

        Returns:
            ``(Qformer, query_tokens, encoder_config)``, where ``query_tokens``
            is an ``nn.Parameter`` of shape
            ``(1, num_query_token, hidden_size)`` drawn from
            ``N(0, initializer_range)``.
        """
        encoder_config = BertConfig.from_pretrained(self.Qformer_model_name)
        encoder_config.num_hidden_layers = num_hidden_layers
        encoder_config.encoder_width = speech_width
        # Enable cross-attention to the audio features, inserted every
        # ``cross_attention_freq`` blocks (every other block by default).
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens, encoder_config

    def hidden(self, batch,):
        """Run the learned query tokens through the Q-Former over ``batch``.

        Args:
            batch: Mapping with ``'embed'`` (audio features of shape
                ``(batch, time, speech_width)``, padded along time) and
                ``'embed_len'`` (the valid length of each sequence).

        Returns:
            The Q-Former's ``last_hidden_state``.
        """
        audio_feature, lens = batch['embed'], batch['embed_len']
        # Size the mask to the padded time axis, not to max(lens). Otherwise it
        # no longer matches ``encoder_hidden_states`` when the batch is padded
        # past its longest sequence.
        frame_atts = generate_length_mask(
            lens, max_length=audio_feature.shape[1]
        ).to(audio_feature.device)
        audio_query_tokens = self.audio_query_tokens.expand(audio_feature.shape[0], -1, -1)
        audio_query_output = self.audio_Qformer.bert(
            query_embeds=audio_query_tokens,
            encoder_hidden_states=audio_feature,
            encoder_attention_mask=frame_atts,
            return_dict=True,
        )
        return audio_query_output.last_hidden_state

    def forward(self, batch) -> dict:
        """Encode ``batch`` into projected query vectors.

        Returns:
            ``{"output": x, "mask": mask}``:
            - ``x`` is the projected Q-Former output.
            - ``mask`` is an all-True bool tensor of shape ``x.shape[:2]`` on
              ``x``'s device.
        """
        # Disable autograd only when the Q-Former is frozen. Otherwise keep the
        # caller's grad mode, so ``freeze_QFormer=False`` can actually train.
        # NOTE(review): when frozen, hiddin_projection also runs without
        # autograd and is never trained. Confirm this is intended, e.g. because
        # its weights come from load_from_pretrained.
        grad_enabled = torch.is_grad_enabled() and not self.freeze_QFormer
        with torch.set_grad_enabled(grad_enabled), torch.amp.autocast(
            device_type=DEVICE_TYPE, enabled=False
        ):
            x = self.hidden(batch)
            x = self.hiddin_projection(x)
        mask = torch.ones(x.shape[:2], dtype=torch.bool, device=x.device)
        return {"output": x, "mask": mask}
if __name__ == '__main__':
    # Smoke test: push a random, padded batch of audio features through the
    # bridge and print the output shapes. Building the model calls
    # BertConfig.from_pretrained("bert-base-uncased"), so that config must be
    # available locally or downloadable.
    model = QformerBridgeNet()
    model.eval()
    # Two sequences padded to 50 frames of width 1024 (the default
    # speech_width); the second has only 30 valid frames.
    batch = {"embed": torch.randn(2, 50, 1024), "embed_len": [50, 30]}
    with torch.no_grad():
        out = model(batch)
    print(out["output"].shape, out["mask"].shape)