"""RoBERTa encoder followed by an LSTM head for sequence classification."""
from abc import ABCMeta

import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    RobertaConfig,
    RobertaModel,
)
from transformers.modeling_outputs import SequenceClassifierOutput


class RoBERTaLSTMConfig(PretrainedConfig):
    """Configuration for ``RoBERTaLSTMForSequenceClassification``.

    Args:
        num_classes: Number of logits produced by the final linear layer.
        embed_dim: Input size of the LSTM. It is fed RoBERTa's last hidden
            state, so it must equal the encoder hidden size (768 for
            ``roberta-base``).
        num_layers: Stored on the config but not read by the model code in
            this file (the LSTM depth is hard-coded to 3).
        hidden_dim_lstm: Hidden size of the LSTM and input size of the final
            linear layer.
        dropout_rate: Dropout probability applied to the pooled LSTM output.
        **kwargs: Forwarded to ``PretrainedConfig``. When neither ``id2label``
            nor ``label2id`` is supplied, the binary ``fake``/``true`` label
            mapping is installed.
    """

    model_type = "robertaLSTMForSequenceClassification"

    def __init__(
        self,
        num_classes=2,
        embed_dim=768,
        num_layers=12,
        hidden_dim_lstm=256,  # hidden size of the LSTM head
        dropout_rate=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.hidden_dim_lstm = hidden_dim_lstm
        self.dropout_rate = dropout_rate
        # Install the default fake/true maps only when the caller did not pass
        # their own; unconditionally assigning them would silently discard a
        # user-supplied labelling.
        if "id2label" not in kwargs and "label2id" not in kwargs:
            self.id2label = {
                0: "fake",
                1: "true",
            }
            self.label2id = {
                "fake": 0,
                "true": 1,
            }


class RoBERTaLSTMForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
    """Sequence classifier: RoBERTa encoder -> 3-layer LSTM -> linear layer.

    The encoder's last hidden state is run through a unidirectional,
    batch-first LSTM. The LSTM output at each sequence's last non-padding
    position is passed through dropout and a linear layer to produce
    ``config.num_classes`` logits.
    """

    config_class = RoBERTaLSTMConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.num_layers = config.num_layers
        self.hidden_dim_lstm = config.hidden_dim_lstm

        self.dropout = nn.Dropout(config.dropout_rate)
        # NOTE(review): this runs on every construction, so pretrained
        # roberta-base weights are loaded (and possibly downloaded) each time
        # the model is built.
        self.roberta = RobertaModel.from_pretrained(
            "roberta-base", output_hidden_states=True, output_attentions=False
        )
        print("RoBERTa Model Loaded")
        # LSTM depth is hard-coded to 3; config.num_layers is not used here.
        self.lstm = nn.LSTM(
            self.embed_dim, self.hidden_dim_lstm, batch_first=True, num_layers=3
        )
        self.fc = nn.Linear(self.hidden_dim_lstm, self.num_classes)

    @staticmethod
    def _last_token_states(sequence_output, attention_mask):
        """Return each sequence's output at its last non-padding position.

        Args:
            sequence_output: Batch-first tensor ``(batch, seq_len, features)``.
            attention_mask: Optional ``(batch, seq_len)`` mask, non-zero for
                real tokens. When ``None``, the final time step is used.

        Returns:
            Tensor of shape ``(batch, features)``.
        """
        if attention_mask is None:
            return sequence_output[:, -1, :]
        device = sequence_output.device
        positions = torch.arange(sequence_output.size(1), device=device)
        mask = attention_mask.to(device=device, dtype=positions.dtype)
        # The largest position with a non-zero mask entry is the last real
        # token. This works for right- and left-padded batches. A row with no
        # real tokens yields index 0, because argmax returns the first maximum.
        last_idx = (mask * positions).argmax(dim=1)
        batch_idx = torch.arange(sequence_output.size(0), device=device)
        return sequence_output[batch_idx, last_idx]

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the classifier.

        Args:
            input_ids: Token ids passed to the RoBERTa encoder.
            attention_mask: Optional mask (1 for real tokens, 0 for padding),
                used by the encoder and to choose the pooled LSTM position.
            labels: Optional class indices; when given, a cross-entropy loss
                against the logits is returned.

        Returns:
            ``SequenceClassifierOutput`` with ``loss`` (``None`` without
            labels), ``logits``, and the encoder's ``hidden_states`` and
            ``attentions`` passed through unchanged.
        """
        roberta_output = self.roberta(
            input_ids=input_ids, attention_mask=attention_mask
        )
        lstm_out, _ = self.lstm(roberta_output.last_hidden_state)
        # Taking the last time step unconditionally would read the LSTM state
        # after it has consumed padding tokens in right-padded batches.
        pooled = self.dropout(self._last_token_states(lstm_out, attention_mask))
        logits = self.fc(pooled)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=roberta_output.hidden_states,
            attentions=roberta_output.attentions,
        )