|
|
import torch |
|
|
import torch.nn as nn |
|
|
from .configuration_smsbert import SMSBertConfig |
|
|
from transformers import pipeline, BertModel, AutoTokenizer, PretrainedConfig,PreTrainedModel, Pipeline, AutoModel,AutoModelForSequenceClassification, BertConfig |
|
|
class SMSBertModel(PreTrainedModel):
    """BERT encoder with two classification heads for SMS messages.

    One linear head scores every token against ``token_class_labels`` and a
    second linear head scores the pooled output against ``seq_labels``.
    ``forward`` decodes both heads to label strings rather than returning
    raw logits.
    """

    config_class = SMSBertConfig

    def __init__(self, config):
        """Build the encoder and both classification heads.

        Args:
            config: Configuration object forwarded to ``PreTrainedModel`` and
                used to instantiate the underlying ``BertModel``.
        """
        super().__init__(config)

        # The label vocabularies are defined first so that the head sizes
        # below are derived from them instead of duplicating their lengths
        # as magic numbers.
        self.seq_labels = [
            "Transaction",
            "Courier",
            "OTP",
            "Expiry",
            "Misc",
            "Tele Marketing",
            "Spam",
        ]

        self.token_class_labels = [
            'O',
            'Courier Service',
            'Credit',
            'Date',
            'Debit',
            'Email',
            'Expiry',
            'Item',
            'Order ID',
            'Organization',
            'OTP',
            'Phone Number',
            'Refund',
            'Time',
            'Tracking ID',
            'URL',
        ]

        # Sub-modules are created in the same order as before so parameter
        # registration order and random-number consumption are unchanged.
        self.bert = BertModel._from_config(config)
        self.dropout = nn.Dropout(0.2)

        hidden_size = self.bert.config.hidden_size
        self.token_classifier = nn.Linear(
            hidden_size, len(self.token_class_labels)
        )
        self.sequence_classifier = nn.Linear(
            hidden_size, len(self.seq_labels)
        )

        # Both heads are plain linear projections, hence the 'linear' gain.
        nn.init.kaiming_normal_(
            self.token_classifier.weight, mode='fan_in', nonlinearity='linear'
        )
        nn.init.kaiming_normal_(
            self.sequence_classifier.weight, mode='fan_in', nonlinearity='linear'
        )

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run both heads and return the decoded labels as a string.

        Only the first element along the leading dimension is decoded; any
        further elements of the input are encoded but ignored in the result.

        Args:
            input_ids: Token id tensor passed to the BERT encoder
                (presumably shaped ``(batch, seq_len)`` -- confirm against
                the caller).
            attention_mask: Optional attention mask forwarded to the encoder.
                ``None`` lets the encoder apply its own default.
            token_type_ids: Optional segment ids forwarded to the encoder.
                ``None`` lets the encoder apply its own default.

        Returns:
            ``str()`` of a dict with two entries: ``"token_classfier"`` maps
            to the list of per-token label names and ``"sequence_classfier"``
            maps to the single sequence label name.
        """
        # Keyword arguments keep the call independent of the positional
        # order of the encoder's forward signature.
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        sequence_output = outputs.last_hidden_state
        pooled_output = outputs.pooler_output

        token_logits = self.token_classifier(self.dropout(sequence_output))
        sequence_logits = self.sequence_classifier(self.dropout(pooled_output))

        # Take the arg-max over the label dimension, then keep only the
        # first element along the leading dimension.
        token_label_ids = token_logits.argmax(2)[0].tolist()
        sequence_label_id = sequence_logits.argmax(1)[0].item()

        token_classification_out = [
            self.token_class_labels[label_id] for label_id in token_label_ids
        ]
        seq_classification_out = self.seq_labels[sequence_label_id]

        # NOTE: the misspelled keys are part of the existing output format
        # and are kept as-is so downstream consumers keep working.
        model_out = str({"token_classfier":token_classification_out, "sequence_classfier": seq_classification_out})
        return model_out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|