"""
Script file used for performing inference with an existing model.
"""
from pathlib import Path
import torch
import json
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification
)

## Class used to encapsulate and handle the logic for inference
class InferenceHandler:
    def __init__(self, bin_model_path: Path, ml_regr_model_path: Path):
        self.bin_tokenizer, self.bin_model = self.init_model_and_tokenizer(bin_model_path)
        self.ml_regr_tokenizer, self.ml_regr_model = self.init_model_and_tokenizer(ml_regr_model_path)

    ## Initializes a model and tokenizer for use in inference using the model's path
    def init_model_and_tokenizer(self, model_path: Path):
        # Read the saved config to recover the base model name and model type
        with open(model_path / 'config.json') as config_file:
            config_json = json.load(config_file)

        model_name = config_json['_name_or_path']
        model_type = config_json['model_type']

        # The tokenizer is loaded from the base model name; the fine-tuned
        # weights are loaded from the local model directory
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForSequenceClassification.from_pretrained(model_path, model_type=model_type)
        model.eval()
        return tokenizer, model

    ## Handles logic used to encode the text for use in binary classification
    def encode_binary(self, text):
        bin_tokenized_input = self.bin_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        return bin_tokenized_input

    ## Handles logic used to encode the text for use in multilabel regression
    def encode_multilabel(self, text):
        ml_tokenized_input = self.ml_regr_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        return ml_tokenized_input

    ## Handles text encoding for both binary classification and multilabel regression
    def encode_input(self, text):
        bin_inputs = self.encode_binary(text)
        ml_inputs = self.encode_multilabel(text)
        return bin_inputs, ml_inputs

    ## Handles performing the full sentiment analysis (binary classification and multilabel regression)
    def classify_text(self, text):
        res_obj = {
            'raw_text': text,
            'text_sentiment': None,
            'numerical_sentiment': None,
            'category_sentiments': {
                'Gender': None,
                'Race': None,
                'Sexuality': None,
                'Disability': None,
                'Religion': None,
                'Unspecified': None
            }
        }

        text_prediction, pred_class = self.discriminatory_inference(text)
        res_obj['text_sentiment'] = text_prediction
        res_obj['numerical_sentiment'] = pred_class

        # Only run category inference when the text is classified as discriminatory
        if pred_class == 1:
            ml_infer_results = self.category_inference(text)
            for idx, key in enumerate(res_obj['category_sentiments'].keys()):
                res_obj['category_sentiments'][key] = ml_infer_results[idx]

        return res_obj

    ## Handles logic for checking the binary classification of the text
    def discriminatory_inference(self, text):
        bin_inputs = self.encode_binary(text)

        with torch.no_grad():
            bin_logits = self.bin_model(**bin_inputs).logits

        # Convert logits to probabilities and take the most likely class
        probs = torch.nn.functional.softmax(bin_logits, dim=-1)
        pred_class = torch.argmax(probs).item()

        bin_label_map = {0: "Non-Discriminatory", 1: "Discriminatory"}
        bin_text_pred = bin_label_map[pred_class]

        return bin_text_pred, pred_class

    ## Handles logic for assessing the categories of discrimination
    def category_inference(self, text):
        ml_inputs = self.encode_multilabel(text)

        with torch.no_grad():
            ml_outputs = self.ml_regr_model(**ml_inputs).logits

        ml_op_list = ml_outputs.squeeze().tolist()

        # Clamp negative regression outputs to 0.0 so category scores stay non-negative
        results = []
        for item in ml_op_list:
            results.append(max(0.0, item))

        return results
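
## Example usage (a minimal sketch, not part of the original handler): the
## model directory paths below are hypothetical placeholders and should be
## replaced with the actual directories containing the saved binary
## classification and multilabel regression models.
if __name__ == '__main__':
    # Hypothetical paths; each directory is assumed to hold a config.json and model weights
    handler = InferenceHandler(
        bin_model_path=Path('models/binary-classification'),
        ml_regr_model_path=Path('models/multilabel-regression')
    )

    # Run the full pipeline and pretty-print the resulting dictionary
    result = handler.classify_text('Example text to analyze for discriminatory sentiment.')
    print(json.dumps(result, indent=2))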