# MisRoberta_lstm / model.py
from abc import ABCMeta

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel, RobertaModel
from transformers.modeling_outputs import SequenceClassifierOutput

class RoBERTaLSTMConfig(PretrainedConfig):
    """Configuration for a RoBERTa encoder topped with an LSTM classification head."""

    model_type = "robertaLSTMForSequenceClassification"

    def __init__(self,
                 num_classes=2,
                 embed_dim=768,        # hidden size of roberta-base
                 num_layers=12,
                 hidden_dim_lstm=256,  # hidden size of the LSTM head
                 dropout_rate=0.1,
                 **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.hidden_dim_lstm = hidden_dim_lstm
        self.dropout_rate = dropout_rate
        self.id2label = {
            0: "fake",
            1: "true",
        }
        self.label2id = {
            "fake": 0,
            "true": 1,
        }
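
# A minimal round-trip sketch, not part of the original file: the directory name
# below is hypothetical, and serialization goes through the standard
# PretrainedConfig save/load machinery.
#
#   config = RoBERTaLSTMConfig()
#   config.save_pretrained("./misroberta-lstm")   # writes config.json
#   config = RoBERTaLSTMConfig.from_pretrained("./misroberta-lstm")
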

class RoBERTaLSTMForSequenceClassification(PreTrainedModel, metaclass=ABCMeta):
    config_class = RoBERTaLSTMConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_classes = config.num_classes
        self.embed_dim = config.embed_dim
        self.num_layers = config.num_layers
        self.hidden_dim_lstm = config.hidden_dim_lstm
        self.dropout = nn.Dropout(config.dropout_rate)
        # Pretrained roberta-base weights are loaded eagerly at construction time.
        self.roberta = RobertaModel.from_pretrained('roberta-base',
                                                    output_hidden_states=True,
                                                    output_attentions=False)
        print("RoBERTa Model Loaded")
        # Three stacked LSTM layers read the token-level RoBERTa embeddings.
        self.lstm = nn.LSTM(self.embed_dim, self.hidden_dim_lstm,
                            batch_first=True, num_layers=3)
        self.fc = nn.Linear(self.hidden_dim_lstm, self.num_classes)
    def forward(self, input_ids, attention_mask, labels=None):
        roberta_output = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        # Token-level embeddings from the final RoBERTa layer: (batch, seq_len, embed_dim).
        hidden_states = roberta_output.last_hidden_state
        out, _ = self.lstm(hidden_states)
        # Classify from the LSTM state at the last time step.
        out = self.dropout(out[:, -1, :])
        logits = self.fc(out)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=roberta_output.hidden_states,
            attentions=roberta_output.attentions,
        )
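
# A minimal usage sketch, not part of the original file. It assumes the standard
# roberta-base tokenizer; the example sentence and max_length are illustrative only.
if __name__ == "__main__":
    from transformers import RobertaTokenizer

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
    model = RoBERTaLSTMForSequenceClassification(RoBERTaLSTMConfig())
    model.eval()  # disable dropout for inference

    inputs = tokenizer("An example headline to classify", return_tensors="pt",
                       truncation=True, max_length=128)
    with torch.no_grad():
        output = model(input_ids=inputs["input_ids"],
                       attention_mask=inputs["attention_mask"])
    # Map the argmax logit back to a label via the config's id2label table.
    print(model.config.id2label[output.logits.argmax(dim=-1).item()])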