from typing import Dict, List, Any

import json

import torch
import torch.nn as nn
from transformers import pipeline, AutoTokenizer, BertConfig, BertModel


class EndpointHandler():
    def __init__(self, path=""):
        # self.pipeline = pipeline("text-classification", model=path)
        self.model = CustomModel("test_bert_config.json")
        self.model.load_state_dict(torch.load("model3.pth"))
        # Disable dropout for inference.
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`)
            date (:obj:`str`)
        Return:
            A :obj:`list` | `dict`: will be serialized and returned
        """
        # get inputs
        inputs = data.pop("inputs", data)
        # date = data.pop("date", None)

        # check if date exists and if it is a holiday
        # if date is not None and date in self.holidays:
        #     return [{"label": "happy", "score": 1}]

        # run normal prediction
        prediction = self.model.classify(inputs)
        prediction = json.dumps(prediction)
        return prediction


class CustomModel(nn.Module):
    def __init__(self, bert_config):
        super(CustomModel, self).__init__()
        # self.bert = BertModel.from_pretrained(base_model_path)
        # Build a randomly initialised BERT backbone from the JSON config;
        # the trained weights are loaded afterwards via load_state_dict.
        self.bert = BertModel(BertConfig.from_json_file(bert_config))
        self.dropout = nn.Dropout(0.2)
        # Per-token head (16 entity labels) and per-sequence head (7 message categories).
        self.token_classifier = nn.Linear(self.bert.config.hidden_size, 16)
        self.sequence_classifier = nn.Linear(self.bert.config.hidden_size, 7)

        # Initialize classifier weights
        nn.init.kaiming_normal_(self.token_classifier.weight, mode='fan_in', nonlinearity='linear')
        nn.init.kaiming_normal_(self.sequence_classifier.weight, mode='fan_in', nonlinearity='linear')

        self.seq_labels = [
            "Transaction",
            "Courier",
            "OTP",
            "Expiry",
            "Misc",
            "Tele Marketing",
            "Spam",
        ]
        self.token_class_labels = [
            'O',
            'Courier Service',
            'Credit',
            'Date',
            'Debit',
            'Email',
            'Expiry',
            'Item',
            'Order ID',
            'Organization',
            'OTP',
            'Phone Number',
            'Refund',
            'Time',
            'Tracking ID',
            'URL',
        ]
        base_model_path = '.'
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
        token_logits = self.token_classifier(self.dropout(sequence_output))
        sequence_logits = self.sequence_classifier(self.dropout(pooled_output))
        return token_logits, sequence_logits

    def classify(self, inputs):
        out = self.tokenizer(inputs, return_tensors="pt")
        with torch.no_grad():
            token_logits, sequence_logits = self.forward(**out)
        # Pick the highest-scoring label per token and for the whole sequence.
        token_ids = token_logits.argmax(dim=2)[0]
        sequence_id = sequence_logits.argmax(dim=1)[0]
        token_classification_out = [self.token_class_labels[i] for i in token_ids.tolist()]
        seq_classification_out = self.seq_labels[sequence_id.item()]
        # return token_classification_out, seq_classification_out
        return {"token_classifier": token_classification_out, "sequence_classifier": seq_classification_out}
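

# ---------------------------------------------------------------------------
# Illustrative local usage (a minimal sketch, not part of the deployed
# handler). It assumes test_bert_config.json, model3.pth and the tokenizer
# files are present in the current working directory, as the classes above
# expect; the sample SMS text is made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = EndpointHandler()
    sample = {"inputs": "Your OTP for the order is 482913. Do not share it with anyone."}
    # Prints a JSON string with per-token entity labels and the sequence category.
    print(handler(sample))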