import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel

from .configuration_bert_ffnn import BertFFNNConfig

class AttentionPooling(nn.Module):
    """Learned attention pooling over the token embeddings of a sequence."""

    def __init__(self, hidden_size):
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, attention_mask):
        # One scalar score per token: (batch, seq_len)
        scores = self.attention(hidden_states).squeeze(-1)
        # Mask out padding tokens before the softmax
        scores = scores.masked_fill(attention_mask == 0, -1e9)
        weights = torch.softmax(scores, dim=-1)
        # Weighted sum over the sequence -> (batch, hidden_size)
        return torch.sum(hidden_states * weights.unsqueeze(-1), dim=1)

class BERT_FFNN(PreTrainedModel):
    """BERT encoder with a configurable feed-forward classification head."""

    config_class = BertFFNNConfig
    base_model_prefix = "bert_ffnn"

    def __init__(self, config):
        super().__init__(config)
        self.bert = AutoModel.from_pretrained(config.bert_model_name)
        self.pooling = config.pooling
        self.use_layer_norm = config.use_layer_norm

        if self.pooling == "attention":
            self.attention_pool = AttentionPooling(self.bert.config.hidden_size)

        # Optionally freeze the whole encoder, or only its first N layers.
        if config.freeze_bert:
            for p in self.bert.parameters():
                p.requires_grad = False
        elif config.freeze_layers > 0:
            for layer in self.bert.encoder.layer[:config.freeze_layers]:
                for p in layer.parameters():
                    p.requires_grad = False

        # Feed-forward head: Linear -> ReLU -> (LayerNorm) -> Dropout per hidden dim,
        # followed by a final projection to the output dimension.
        layers = []
        in_dim = self.bert.config.hidden_size
        for h_dim in config.hidden_dims:
            layers.append(nn.Linear(in_dim, h_dim))
            layers.append(nn.ReLU())
            if config.use_layer_norm:
                layers.append(nn.LayerNorm(h_dim))
            layers.append(nn.Dropout(config.dropout))
            in_dim = h_dim
        layers.append(nn.Linear(in_dim, config.output_dim))
        self.classifier = nn.Sequential(*layers)

        self.post_init()

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        if self.pooling == "mean":
            # Mean over non-padding tokens only.
            mask = attention_mask.unsqueeze(-1).float()
            sum_emb = (outputs.last_hidden_state * mask).sum(1)
            features = sum_emb / mask.sum(1).clamp(min=1e-9)
        elif self.pooling == "max":
            # Max over non-padding tokens (padding positions set to -inf).
            mask = attention_mask.unsqueeze(-1).float()
            masked_emb = outputs.last_hidden_state.masked_fill(mask == 0, float("-inf"))
            features, _ = masked_emb.max(dim=1)
        elif self.pooling == "attention":
            features = self.attention_pool(outputs.last_hidden_state, attention_mask)
        else:  # CLS pooling
            features = (
                outputs.pooler_output
                if getattr(outputs, "pooler_output", None) is not None
                else outputs.last_hidden_state[:, 0]
            )

        logits = self.classifier(features)
        return logits
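

# Usage sketch, illustrative only and not part of the original module. It assumes
# BertFFNNConfig accepts the fields referenced above as constructor kwargs and that
# this file sits inside a package, so run it as a module (e.g.
# `python -m <your_package>.modeling_bert_ffnn`) for the relative import to resolve.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = BertFFNNConfig(
        bert_model_name="bert-base-uncased",  # assumed checkpoint; any BERT-style encoder works
        pooling="mean",                       # "mean", "max", "attention"; anything else falls back to CLS
        hidden_dims=[256, 64],
        dropout=0.1,
        use_layer_norm=True,
        freeze_bert=False,
        freeze_layers=0,
        output_dim=2,
    )
    model = BERT_FFNN(config).eval()
    tokenizer = AutoTokenizer.from_pretrained(config.bert_model_name)

    batch = tokenizer(
        ["a short example sentence", "another one"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # (2, config.output_dim)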