import torch
from torch import nn
from .qformer import BertLMHeadModel, BertConfig
# from header import *
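
# TextFcLayer maps text-side (LLM) embeddings to the embedding space expected by a
# downstream visual output module (see the class docstring). Three projection modes
# are implemented:
#   - 'linear':      a single nn.Linear shared across tokens
#   - 'transformer': a small encoder-decoder Transformer that decodes a fixed set
#                    of learned query embeddings
#   - 'qformer':     a BLIP-2-style querying transformer built on bert-base-uncased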
class TextFcLayer(nn.Module):
    """Layers used in mapping text embeddings to visual outputs."""

    @classmethod
    def init_Qformer(cls, num_query_token, vision_width, num_hidden_layers=2, cross_attention_freq=1):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        encoder_config.num_hidden_layers = num_hidden_layers
        # insert a cross-attention layer every `cross_attention_freq` blocks
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel.from_pretrained("bert-base-uncased", config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens
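
    # Note: init_Qformer mirrors the BLIP-2-style Q-Former constructor: a small BERT
    # with cross-attention enabled, plus `num_query_token` learned query embeddings.
    # A rough usage sketch (dimensions are illustrative only):
    #   qformer, query_tokens = TextFcLayer.init_Qformer(num_query_token=4, vision_width=768)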

    def __init__(self, in_dim: int, out_dim: int, num_input_tokens: int = 1, num_output_tokens: int = 1,
                 mode: str = 'linear',
                 freeze_qformer=False):
        """
        :param mode: one of ['linear', 'transformer', 'qformer']
        :param freeze_qformer: whether to freeze the weights of the Q-Former
        """
        super().__init__()
        self.num_input_tokens = num_input_tokens
        self.num_output_tokens = num_output_tokens
        self.mode = mode
        self.out_dim = out_dim
        if mode == 'linear':
            self.model = nn.Linear(in_dim, out_dim)
        elif mode == 'transformer':
            hidden_dim = 512
            self.fc = nn.Linear(in_dim, hidden_dim)
            self.tfm = nn.Transformer(batch_first=True, norm_first=True,
                                      d_model=hidden_dim, num_encoder_layers=4, num_decoder_layers=4,
                                      dim_feedforward=hidden_dim * 4, dropout=0.0, nhead=4)
            self.model = nn.Linear(hidden_dim, out_dim)
            self.query_embs = nn.Parameter(torch.randn(1, num_output_tokens, hidden_dim))
        elif mode == 'qformer':
            print('Loading Q-Former')
            hidden_dim = 768
            self.fc = nn.Linear(in_dim, hidden_dim)
            self.Qformer, self.query_tokens = self.init_Qformer(
                num_output_tokens, hidden_dim
            )
            # Only the query self-attention / cross-attention path is used here, so
            # drop the LM head, the word/position embeddings and the FFN sublayers.
            self.Qformer.cls = None
            self.Qformer.bert.embeddings.word_embeddings = None
            self.Qformer.bert.embeddings.position_embeddings = None
            for layer in self.Qformer.bert.encoder.layer:
                layer.output = None
                layer.intermediate = None
            # self.load_from_pretrained(url_or_filename=q_former_model)
            self.model = nn.Linear(hidden_dim, out_dim)
            if freeze_qformer:
                for param in self.Qformer.parameters():
                    param.requires_grad = False
                self.Qformer = self.Qformer.eval()
                self.query_tokens.requires_grad = False
            print('Loading Q-Former Done')
        else:
            raise NotImplementedError(mode)
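
    # forward() takes `x` (text-side hidden states) and `input_embs` (the matching
    # token input embeddings), both shaped (N, num_input_tokens, in_dim), and returns
    # projected embeddings, typically of shape (N, num_output_tokens, out_dim); the
    # assert at the end checks that total output size.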
    def forward(self, x: torch.Tensor, input_embs: torch.Tensor) -> torch.Tensor:
        outputs = None
        if isinstance(self.model, nn.ModuleList):
            # Per-token projection: one linear layer per input token.
            assert len(self.model) == x.shape[1] == self.num_input_tokens, (
                len(self.model), x.shape, self.num_input_tokens)
            outputs = []
            for i in range(self.num_input_tokens):
                outputs.append(self.model[i](x[:, i, :]))  # (N, D)
            outputs = torch.stack(outputs, dim=1)  # (N, T, D)
        elif self.mode == 'linear':
            # Shared linear projection; keep only the first num_output_tokens tokens
            # if more are produced.
            outputs = self.model(x)
            if outputs.shape[1] != self.num_output_tokens:
                outputs = outputs[:, :self.num_output_tokens, :]
        elif self.mode == 'transformer':
            # Fuse the hidden states with the token input embeddings, project to the
            # transformer width, then decode the learned query embeddings.
            x = x + input_embs
            x = self.fc(x)
            x = self.tfm(x, self.query_embs.repeat(x.shape[0], 1, 1))
            outputs = self.model(x)  # (N, num_output_tokens, out_dim)
        elif self.mode == 'qformer':
            # Fuse and project as above, then let the learned query tokens
            # cross-attend to the projected features inside the Q-Former.
            x = x + input_embs
            x = self.fc(x)
            image_atts = torch.ones(x.size()[:-1], dtype=torch.long).to(x.device)
            query_tokens = self.query_tokens.expand(x.shape[0], -1, -1)
            outputs = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=x,
                encoder_attention_mask=image_atts,
                return_dict=True,
            ).last_hidden_state
            outputs = self.model(outputs)  # (N, num_output_tokens, out_dim)

        assert outputs.shape[1] == 1 or (outputs.shape[1] * outputs.shape[2] == self.num_output_tokens * self.out_dim), (
            outputs.shape, self.num_output_tokens)
        return outputs  # (N, T, D)
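

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module). The dimensions
    # below are illustrative assumptions only: in_dim=4096 stands in for an LLM
    # hidden size, out_dim=768 / num_output_tokens=77 for a CLIP-style text
    # conditioning space; swap in whatever your decoder actually expects.
    # Run it as a module (e.g. `python -m model.layers` from the package root) so
    # the relative import of .qformer resolves.
    layer = TextFcLayer(in_dim=4096, out_dim=768,
                        num_input_tokens=4, num_output_tokens=77,
                        mode='transformer')
    hidden_states = torch.randn(2, 4, 4096)   # (N, num_input_tokens, in_dim)
    input_embs = torch.randn(2, 4, 4096)      # same shape as hidden_states
    out = layer(hidden_states, input_embs)
    print(out.shape)  # expected: torch.Size([2, 77, 768])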