import os
from typing import Any, Dict

import torch
import torch.nn as nn
from transformers import AutoTokenizer, BertConfig, BertModel


class EndpointHandler:
    def __init__(self, path: str = ""):
        # `path` points at the model repository when deployed as a custom
        # Inference Endpoints handler; fall back to the working directory
        # for local runs.
        base_path = path or "."
        self.model = CustomModel(os.path.join(base_path, "test_bert_config.json"), base_path)
        self.model.load_state_dict(torch.load(os.path.join(base_path, "model3.pth"), map_location="cpu"))
        # Inference only: disable dropout so predictions are deterministic.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        data args:
            inputs (:obj:`str`): the text to classify
        Return:
            A :obj:`dict` with the token- and sequence-level predictions;
            the serving layer serializes it to JSON.
        """
        inputs = data.pop("inputs", data)

        # `classify` already returns a plain dict; wrapping it in
        # json.dumps here would double-encode the response.
        prediction = self.model.classify(inputs)
        return prediction


class CustomModel(nn.Module):
    def __init__(self, bert_config: str, base_model_path: str = "."):
        super().__init__()

        # Build the BERT backbone from a local JSON config; the trained
        # weights are loaded afterwards via `load_state_dict`.
        self.bert = BertModel(BertConfig.from_json_file(bert_config))
        self.dropout = nn.Dropout(0.2)
        # Two heads: 16 token-level classes and 7 sequence-level classes,
        # matching the label lists below.
        self.token_classifier = nn.Linear(self.bert.config.hidden_size, 16)
        self.sequence_classifier = nn.Linear(self.bert.config.hidden_size, 7)

        nn.init.kaiming_normal_(self.token_classifier.weight, mode='fan_in', nonlinearity='linear')
        nn.init.kaiming_normal_(self.sequence_classifier.weight, mode='fan_in', nonlinearity='linear')

        self.seq_labels = [
            "Transaction",
            "Courier",
            "OTP",
            "Expiry",
            "Misc",
            "Tele Marketing",
            "Spam",
        ]

        self.token_class_labels = [
            'O',
            'Courier Service',
            'Credit',
            'Date',
            'Debit',
            'Email',
            'Expiry',
            'Item',
            'Order ID',
            'Organization',
            'OTP',
            'Phone Number',
            'Refund',
            'Time',
            'Tracking ID',
            'URL',
        ]

        # Tokenizer files are expected alongside the other model artifacts;
        # `base_model_path` defaults to the working directory.
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # Per-token hidden states feed the token head; the pooled [CLS]
        # representation feeds the sequence head.
        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output

        token_logits = self.token_classifier(self.dropout(sequence_output))
        sequence_logits = self.sequence_classifier(self.dropout(pooled_output))

        return token_logits, sequence_logits

    def classify(self, inputs: str) -> Dict[str, Any]:
        out = self.tokenizer(inputs, return_tensors="pt")
        # Inference only, so skip gradient tracking.
        with torch.no_grad():
            token_classification_logits, sequence_logits = self.forward(**out)
        token_class_ids = token_classification_logits.argmax(dim=2)[0]
        sequence_class_id = sequence_logits.argmax(dim=1)[0].item()
        token_classification_out = [self.token_class_labels[i] for i in token_class_ids.tolist()]
        seq_classification_out = self.seq_labels[sequence_class_id]

        return {
            "token_classifier": token_classification_out,
            "sequence_classifier": seq_classification_out,
        }