# zdou0830's picture
# desco
# 749745d
# raw history blame
# No virus
# 3.25 kB
from copy import deepcopy
import numpy as np
import torch
from torch import nn
# from pytorch_pretrained_bert.modeling import BertModel
from transformers import BertConfig, RobertaConfig, RobertaModel, BertModel
class BertEncoder(nn.Module):
    """Language backbone that wraps a pretrained BERT or RoBERTa model.

    The encoder runs the wrapped transformer on tokenized text, averages the
    last ``N_LAYERS`` hidden states, zeroes out masked token positions and
    also produces one mask-weighted mean vector per sequence.

    Attributes:
        cfg: The configuration object passed to the constructor.
        bert_name: ``cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE``.
        model: The HuggingFace model, loaded without its pooling layer.
        language_dim: Feature width exposed to callers (768).
        num_layers: ``cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS``; how many of the
            final hidden layers are averaged in ``forward``.
    """

    def __init__(self, cfg):
        """Load the pretrained backbone selected by ``cfg``.

        Args:
            cfg: Config object exposing ``MODEL.LANGUAGE_BACKBONE.MODEL_TYPE``,
                ``MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT`` and
                ``MODEL.LANGUAGE_BACKBONE.N_LAYERS`` (and, for ``forward``,
                ``MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS``).

        Raises:
            NotImplementedError: If ``MODEL_TYPE`` is neither
                ``"bert-base-uncased"`` nor ``"roberta-base"``.
        """
        super().__init__()
        self.cfg = cfg
        self.bert_name = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
        print("LANGUAGE BACKBONE USE GRADIENT CHECKPOINTING: ", self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT)

        # Both supported backbones are loaded the same way; only the
        # config/model classes differ.
        if self.bert_name == "bert-base-uncased":
            config_cls, model_cls = BertConfig, BertModel
        elif self.bert_name == "roberta-base":
            config_cls, model_cls = RobertaConfig, RobertaModel
        else:
            raise NotImplementedError(
                "Unsupported language backbone: {!r}".format(self.bert_name)
            )

        config = config_cls.from_pretrained(self.bert_name)
        config.gradient_checkpointing = self.cfg.MODEL.LANGUAGE_BACKBONE.USE_CHECKPOINT
        self.model = model_cls.from_pretrained(self.bert_name, add_pooling_layer=False, config=config)
        self.language_dim = 768
        self.num_layers = cfg.MODEL.LANGUAGE_BACKBONE.N_LAYERS

    def _pad_token_id(self):
        """Return the backbone's padding token id, falling back to 0."""
        pad_id = getattr(getattr(self.model, "config", None), "pad_token_id", None)
        return 0 if pad_id is None else pad_id

    def forward(self, x):
        """Encode a batch of tokenized text.

        Args:
            x: Mapping with ``"input_ids"`` and ``"attention_mask"`` tensors
                (assumed to be ``[batch, seq_len]`` with a 0/1 mask -- confirm
                against the tokenizer call site).

        Returns:
            dict with:
                ``"aggregate"``: mask-weighted mean of the token features per
                    sequence; a row whose mask is all zeros yields zeros.
                ``"embedded"``: per-token features with masked positions set
                    to zero.
                ``"masks"``: the attention mask exactly as passed in. When
                    ``USE_DOT_PRODUCT_TOKEN_LOSS`` is off this is NOT
                    truncated, unlike ``"embedded"`` and ``"hidden"``.
                ``"hidden"``: the last hidden state of the backbone.
        """
        input_ids = x["input_ids"]
        mask = x["attention_mask"]

        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
            # with padding: feed the full (fixed-length) sequences
            model_ids, model_mask = input_ids, mask
        else:
            # without padding: crop every sequence to the longest count of
            # non-pad tokens in the batch. The pad id comes from the backbone
            # config because it is not 0 for every model (RoBERTa uses 1,
            # and 0 is its <s> token).
            max_len = (input_ids != self._pad_token_id()).sum(1).max().item()
            model_ids, model_mask = input_ids[:, :max_len], mask[:, :max_len]

        outputs = self.model(
            input_ids=model_ids,
            attention_mask=model_mask,
            output_hidden_states=True,
        )
        # hidden_states holds the input (embedding) layer followed by the
        # transformer layers; drop the input layer.
        encoded_layers = outputs.hidden_states[1:]

        # language embedding has shape [len(phrase), seq_len, language_dim]
        features = torch.stack(encoded_layers[-self.num_layers:], 1).mean(1)
        # NOTE(review): the mean above already normalizes over the layers, so
        # this second division rescales the features when N_LAYERS > 1. It is
        # kept unchanged because existing checkpoints may depend on this
        # scaling -- confirm before removing.
        features = features / self.num_layers

        embedded = features * model_mask.unsqueeze(-1).float()

        # Number of attended tokens per sequence, taken from the same mask
        # that was applied to `embedded`. Zero counts are replaced by 1 so a
        # fully masked row pools to zeros instead of NaN (0 / 0).
        token_counts = model_mask.sum(-1).unsqueeze(-1).float()
        token_counts = token_counts.masked_fill(token_counts == 0, 1.0)
        aggregate = embedded.sum(1) / token_counts

        return {
            "aggregate": aggregate,
            "embedded": embedded,
            "masks": mask,
            "hidden": encoded_layers[-1],
        }