emo-detector-space / hf_model / modeling_bert_ffnn.py

import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel
from .configuration_bert_ffnn import BertFFNNConfig

class AttentionPooling(nn.Module):
    """Learned attention pooling: scores each token, softmaxes the scores over
    the sequence, and returns the weighted sum of the hidden states."""

    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask):
        # (batch, seq_len, hidden) -> (batch, seq_len) token scores
        scores = self.attention(hidden_states).squeeze(-1)
        # Padding positions get a large negative score so softmax ignores them.
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)
        # Weighted sum over the sequence dimension -> (batch, hidden)
        return torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)
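
# A minimal shape sketch for AttentionPooling (illustrative only, not part of the
# model; the tensor sizes below are made up for the example):
#
#   pool = AttentionPooling(hidden_size=8)
#   hidden_states = torch.randn(2, 4, 8)            # (batch, seq_len, hidden)
#   attention_mask = torch.tensor([[1, 1, 1, 0],
#                                  [1, 1, 0, 0]])    # 1 = real token, 0 = padding
#   pooled = pool(hidden_states, attention_mask)     # -> shape (2, 8)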

class BERT_FFNN(PreTrainedModel):
    config_class = BertFFNNConfig
    base_model_prefix = "bert_ffnn"

    def __init__(self, config):
        super().__init__(config)
        # Pretrained transformer backbone, loaded by name from the config.
        self.bert = AutoModel.from_pretrained(config.bert_model_name)
        self.pooling = config.pooling
        self.use_layer_norm = config.use_layer_norm

        if self.pooling == "attention":
            self.attention_pool = AttentionPooling(self.bert.config.hidden_size)

        # Optionally freeze the whole backbone, or only its first N encoder layers.
        if config.freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False
        elif config.freeze_layers > 0:
            for layer in self.bert.encoder.layer[:config.freeze_layers]:
                for p in layer.parameters():
                    p.requires_grad = False

        # Feed-forward classification head: Linear -> ReLU [-> LayerNorm] -> Dropout
        # for each entry in hidden_dims, followed by a final projection to output_dim.
        layers = []
        in_dim = self.bert.config.hidden_size
        for h_dim in config.hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            if config.use_layer_norm:
                layers.append(nn.LayerNorm(h_dim))
            layers.append(nn.Dropout(config.dropout))
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, config.output_dim))
        self.classifier = nn.Sequential(*layers)

        self.post_init()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        if self.pooling == "mean":
            # Mean over non-padding tokens only.
            mask = attention_mask.unsqueeze(-1).float()
            sum_emb = (outputs.last_hidden_state * mask).sum(1)
            features = sum_emb / mask.sum(1).clamp(min=1e-9)
        elif self.pooling == "max":
            # Max over the sequence, with padding positions set to -inf first.
            mask = attention_mask.unsqueeze(-1).float()
            masked_emb = outputs.last_hidden_state.masked_fill(mask == 0, float("-inf"))
            features, _ = masked_emb.max(dim=1)
        elif self.pooling == "attention":
            features = self.attention_pool(outputs.last_hidden_state, attention_mask)
        else:  # CLS pooling
            features = (
                outputs.pooler_output
                if getattr(outputs, "pooler_output", None) is not None
                else outputs.last_hidden_state[:, 0]
            )

        logits = self.classifier(features)
        return logits
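
# A minimal usage sketch (assumptions: BertFFNNConfig accepts the fields read above
# as keyword arguments, and "bert-base-uncased" is a valid backbone choice; adjust
# both to match the actual configuration class).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = BertFFNNConfig(
        bert_model_name="bert-base-uncased",
        pooling="mean",
        hidden_dims=[256],
        dropout=0.1,
        use_layer_norm=True,
        freeze_bert=False,
        freeze_layers=0,
        output_dim=6,
    )
    model = BERT_FFNN(config)
    model.eval()  # disable dropout for the example forward pass
    tokenizer = AutoTokenizer.from_pretrained(config.bert_model_name)

    batch = tokenizer(["I am so happy today!"], return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # torch.Size([1, 6]) with output_dim=6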