import tensorflow as tf
from transformers import PreTrainedModel, TFPreTrainedModel
from transformers.modeling_tf_utils import TFSequenceClassificationLoss
class TFNewsClassifier(TFPreTrainedModel, TFSequenceClassificationLoss):
    """Stacked-LSTM sequence classifier built on the Transformers TF base class.

    The network is: LSTM(128, full sequence) -> dropout -> LSTM(64, last
    state) -> dropout -> dense softmax over ``config.num_labels`` classes.

    There is no embedding layer in this model, so ``call`` feeds its input
    straight into the first LSTM.
    """

    def __init__(self, config):
        """Build the layers.

        Args:
            config: Model configuration forwarded to ``TFPreTrainedModel``.
                Only ``config.num_labels`` is read here; it sets the width
                of the output layer.
        """
        super().__init__(config)
        self.num_labels = config.num_labels

        # First recurrent layer keeps the whole sequence so that the second
        # LSTM still receives one vector per timestep.
        self.lstm = tf.keras.layers.LSTM(128, return_sequences=True)
        # Second recurrent layer returns only its final output, collapsing
        # the time axis.
        self.lstm2 = tf.keras.layers.LSTM(64)
        # One Dropout instance is deliberately shared by both call sites in
        # ``call``; the layer has no weights, so reuse is safe.
        self.dropout = tf.keras.layers.Dropout(0.5)
        # NOTE(review): this head emits softmax probabilities, not logits.
        # The inherited TFSequenceClassificationLoss.hf_compute_loss appears
        # to expect raw logits (from_logits=True) -- confirm before feeding
        # this model's output to it.
        self.classifier = tf.keras.layers.Dense(self.num_labels, activation='softmax')

    def call(self, inputs, training: bool = False) -> tf.Tensor:
        """Run the forward pass.

        Args:
            inputs: Tensor passed directly to the first LSTM layer.
                Presumably shaped ``(batch, timesteps, features)``, as Keras
                recurrent layers require 3-D input -- confirm against the
                caller.
            training: When true, dropout is active; otherwise it is a no-op.

        Returns:
            The softmax output of the dense classifier, one row of
            ``num_labels`` class probabilities per input sequence.
        """
        x = self.lstm(inputs)
        x = self.dropout(x, training=training)
        x = self.lstm2(x)
        x = self.dropout(x, training=training)
        return self.classifier(x)