# Uploaded by FatimahEmadEldin:
# "Upload fine-tuned hybrid document readability model"
# (commit d5cbd37, verified)
# This file defines the custom architecture for your document-level hybrid model.
import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel, AutoConfig
class HybridRegressionModel(PreTrainedModel):
    """
    A hybrid model that combines a transformer base with additional numerical features.

    The transformer's sequence representation is concatenated with a vector of
    extra numerical features and passed through a single linear layer that
    produces one regression value per example. This architecture MUST match the
    one used to create the checkpoint.

    Args:
        config: Transformer configuration. Its ``_name_or_path`` is used to load
            the transformer body via ``AutoModel.from_pretrained``.
        num_extra_features: Number of numerical features appended to the
            transformer representation (default 7).
    """

    # Associates the model with the configuration class that
    # ``PreTrainedModel`` uses when loading the config.
    config_class = AutoConfig

    def __init__(self, config, num_extra_features=7):
        super().__init__(config)
        self.num_extra_features = num_extra_features
        # Load the transformer body from the name/path recorded in the config.
        self.transformer = AutoModel.from_pretrained(config._name_or_path, config=config)
        # Regression head: transformer representation + extra features -> 1 value.
        self.regressor = nn.Linear(
            self.transformer.config.hidden_size + num_extra_features, 1
        )

    def forward(self, input_ids, attention_mask, extra_features, labels=None):
        """
        Predict one regression value per example.

        Args:
            input_ids: Token ids passed to the transformer.
            attention_mask: Attention mask passed to the transformer.
            extra_features: Numerical features concatenated along dim 1 with the
                transformer representation; presumably shaped
                (batch, num_extra_features) — TODO confirm against callers.
            labels: Optional regression targets. When given, an MSE loss between
                the flattened predictions and flattened targets is computed.

        Returns:
            ``(loss, logits)`` when ``labels`` is given, otherwise ``logits``.
            ``logits`` is the output of the single-unit linear head, shape
            (batch, 1).
        """
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)

        # Prefer the pooler output. Some AutoModel bodies do not return one, so
        # fall back to the first token's final hidden state (same width).
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is None:
            pooled = outputs.last_hidden_state[:, 0]

        # Align the features with the transformer output. Otherwise float64
        # features (e.g. from NumPy) promote the concatenation to float64 and
        # the float32 linear layer rejects it.
        extra_features = extra_features.to(device=pooled.device, dtype=pooled.dtype)
        combined_features = torch.cat((pooled, extra_features), dim=1)

        # Raw regression output (kept under the name "logits" for callers).
        logits = self.regressor(combined_features)

        if labels is None:
            return logits

        # Flatten both sides so (batch,), (batch, 1) and batch size 1 line up
        # element-wise. Cast targets (e.g. int or float64) to the logits dtype.
        target = labels.to(device=logits.device, dtype=logits.dtype).reshape(-1)
        loss = nn.MSELoss()(logits.reshape(-1), target)
        return loss, logits