| |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import AutoModel, PreTrainedModel, AutoConfig |
|
|
class HybridRegressionModel(PreTrainedModel):
    """Transformer encoder combined with extra numerical features.

    The encoder's pooled output is concatenated with a vector of additional
    numerical features and passed through a single linear layer that emits
    one regression value per example. This architecture MUST match the one
    used to create the checkpoint.
    """

    # Config resolution is delegated to AutoConfig.
    # NOTE(review): presumably so the config of whichever base model the
    # checkpoint names can be loaded - confirm against the loading code.
    config_class = AutoConfig

    def __init__(self, config, num_extra_features=7):
        """Build the base transformer and the regression head.

        Args:
            config: Model configuration. ``config._name_or_path`` identifies
                the base transformer whose pretrained weights are loaded.
            num_extra_features: Number of extra numerical features appended
                to the pooled transformer output before the linear head.
        """
        super().__init__(config)

        # Loads pretrained weights for the base model each time the hybrid
        # model is constructed.
        self.transformer = AutoModel.from_pretrained(
            config._name_or_path, config=config
        )

        # The head consumes the pooled output and the extra features side by
        # side, so its input width is the sum of both.
        self.regressor = nn.Linear(
            self.transformer.config.hidden_size + num_extra_features, 1
        )

    def forward(self, input_ids, attention_mask, extra_features, labels=None):
        """Run the encoder, append the extra features and regress one value.

        Args:
            input_ids: Token ids forwarded to the base transformer.
            attention_mask: Attention mask forwarded to the base transformer.
            extra_features: 2-D tensor concatenated to the pooled output along
                dim 1; its width must equal the ``num_extra_features`` given
                at construction. It is cast to the pooled output's dtype.
            labels: Optional regression targets, one per example. When given,
                a mean-squared-error loss is computed.

        Returns:
            ``(loss, logits)`` when ``labels`` is provided, otherwise
            ``logits`` alone. ``logits`` has one column per example.

        Raises:
            ValueError: If the base transformer output carries no
                ``pooler_output``.
        """
        outputs = self.transformer(
            input_ids=input_ids, attention_mask=attention_mask
        )

        # Fail with a clear message instead of an opaque AttributeError /
        # torch.cat TypeError when the encoder provides no pooled output.
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is None:
            raise ValueError(
                "The base transformer did not return `pooler_output`; "
                "HybridRegressionModel requires an encoder that provides one."
            )

        # torch.cat promotes mixed dtypes (e.g. float64 features built from
        # NumPy data), which would then mismatch the regressor's weights.
        extra_features = extra_features.to(dtype=pooled.dtype)
        combined_features = torch.cat((pooled, extra_features), dim=1)

        logits = self.regressor(combined_features)

        if labels is None:
            return logits

        # Flatten both sides to 1-D so shapes agree for any batch size, and
        # align the target dtype with the predictions for MSELoss.
        targets = labels.to(dtype=logits.dtype).reshape(-1)
        loss = nn.MSELoss()(logits.reshape(-1), targets)
        return loss, logits
|
|