import os
import sys

import torch
import torch.nn as nn

# Make the local Qformer module importable regardless of the working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from Qformer import BertConfig, BertLMHeadModel

# Use an Ascend NPU when torch_npu is installed; importing transfer_to_npu
# redirects CUDA API calls to the NPU backend. Otherwise fall back to CUDA.
try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
    DEVICE_TYPE = "npu"
except ModuleNotFoundError:
    DEVICE_TYPE = "cuda"
def generate_length_mask(lens, max_length=None):
    """Return an (N, max_length) 0/1 mask with lens[i] ones in row i."""
    lens = torch.as_tensor(lens)
    N = lens.size(0)
    if max_length is None:
        max_length = int(lens.max())
    idxs = torch.arange(max_length).repeat(N).view(N, max_length)
    idxs = idxs.to(lens.device)
    mask = (idxs < lens.view(-1, 1)).int()
    return mask
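# Example (a minimal sketch): generate_length_mask([2, 4]) yields
#   tensor([[1, 1, 0, 0],
#           [1, 1, 1, 1]], dtype=torch.int32)
# i.e. row i contains lens[i] ones followed by zero padding.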
class QformerBridgeNet(torch.nn.Module):
    def __init__(self, Qformer_model_name: str = "bert-base-uncased", num_query_token: int = 32,
                 hiddin_size: int = 1024, speech_width: int = 1024, freeze_QFormer: bool = True,
                 load_from_pretrained: str = None):
        super().__init__()
        self.Qformer_model_name = Qformer_model_name
        self.audio_Qformer, self.audio_query_tokens, encoder_config = self.init_Qformer(
            num_query_token=num_query_token, speech_width=speech_width
        )
        # The Q-Former only bridges features, so drop the LM head, the text
        # embeddings, and the text-branch feed-forward sublayers; none of
        # these are used when only query embeddings are fed.
        self.audio_Qformer.cls = None
        self.audio_Qformer.bert.embeddings.word_embeddings = None
        self.audio_Qformer.bert.embeddings.position_embeddings = None
        for layer in self.audio_Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.freeze_QFormer = freeze_QFormer
        if freeze_QFormer:
            for name, param in self.audio_Qformer.named_parameters():
                param.requires_grad = False
            self.audio_Qformer.eval()
            self.audio_query_tokens.requires_grad = False
        self.hiddin_projection = torch.nn.Linear(encoder_config.hidden_size, hiddin_size)
        # torch.nn.init.xavier_uniform_(self.hiddin_projection.weight, gain=torch.nn.init.calculate_gain("relu"))
        if load_from_pretrained:
            state_dict = torch.load(load_from_pretrained, map_location="cpu")
            # Skip stale projection-head keys from the checkpoint.
            del_key = ["projection.weight", "projection.bias"]
            del_state_dict = {k: v for k, v in state_dict.items() if k not in del_key}
            self.load_state_dict(del_state_dict)
            print("Loaded adaptor_model_pt from", load_from_pretrained)
    def init_Qformer(self, num_query_token, speech_width, num_hidden_layers=2, cross_attention_freq=2):
        encoder_config = BertConfig.from_pretrained(self.Qformer_model_name)
        encoder_config.num_hidden_layers = num_hidden_layers
        encoder_config.encoder_width = speech_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens, encoder_config
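    # A sketch of the resulting stack, assuming Qformer.py follows the BLIP-2
    # convention of adding cross-attention when layer_num % cross_attention_freq
    # == 0: with the defaults above, layer 0 cross-attends to the audio
    # features and layer 1 is self-attention only.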
    def hidden(self, batch):
        audio_feature, lens = batch['embed'], batch['embed_len']
        # Mask padded frames so cross-attention only attends to real audio.
        frame_atts = generate_length_mask(lens).to(audio_feature.device)
        audio_query_tokens = self.audio_query_tokens.expand(audio_feature.shape[0], -1, -1)
        audio_query_output = self.audio_Qformer.bert(
            query_embeds=audio_query_tokens,        # (B, num_query_token, hidden)
            encoder_hidden_states=audio_feature,    # (B, T, speech_width)
            encoder_attention_mask=frame_atts,      # (B, T)
            return_dict=True,
        )
        audio_hidden = audio_query_output.last_hidden_state
        return audio_hidden
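    # Shape sketch: with bert-base-uncased (hidden_size 768) and the defaults,
    # hidden() pools (B, T, speech_width) audio features into a fixed
    # (B, 32, 768) tensor, one row per learned query token.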
| def forward(self, batch) -> torch.Tensor: | |
| with torch.no_grad(), torch.amp.autocast( | |
| device_type=DEVICE_TYPE, enabled=False | |
| ): | |
| x = self.hidden(batch) | |
| x = self.hiddin_projection(x) | |
| mask = torch.ones(x.shape[:2]) | |
| mask = (mask == 1).to(x.device) | |
| return {"output": x, "mask": mask} | |
if __name__ == '__main__':
    bridge = QformerBridgeNet()
    bridge.eval()
    batch = {
        "embed": torch.randn(2, 50, 1024),      # (B, T, speech_width)
        "embed_len": torch.tensor([50, 30]),    # valid frames per item
    }
    with torch.no_grad():
        output = bridge(batch)
    print(output["output"].shape, output["mask"].shape)   # (2, 32, 1024), (2, 32)