from tensorflow.keras.models import load_model | |
from transformers import TFBertModel | |
class MultiTaskModel:
    """
    A class used to represent a MultiTaskModel, which classifies the emotion
    and toxicity of Valorant chat messages.

    Wraps a saved Keras model (``.h5``) containing a ``TFBertModel`` layer,
    plus a preprocessor object that turns raw text into model input and
    decodes the model's outputs.
    """

    # Default model files, selected by the ``is_multilingual`` flag.
    # The multilingual file was built on 'bert-base-multilingual-cased',
    # the base file on 'bert-base-uncased'.
    MULTILINGUAL_MODEL_PATH = 'model_with_bert_multilingual.h5'
    BASE_MODEL_PATH = 'model_with_bert_base.h5'

    def __init__(self, is_multilingual=False, preprocessor=None, model_path=None):
        """
        Load the saved multi-task model and attach a preprocessor.

        Args:
            is_multilingual: If True, load the multilingual-BERT model file;
                otherwise load the bert-base one.
            preprocessor: Object exposing ``preprocess_text(text)`` and a
                ``decoder`` with ``toxicity(...)`` and ``emotion(...)``.
                May be None and attached later via ``load_preprocess``.
            model_path: Optional path that overrides the default ``.h5`` file
                chosen by ``is_multilingual``.
        """
        if model_path is None:
            model_path = (
                self.MULTILINGUAL_MODEL_PATH if is_multilingual else self.BASE_MODEL_PATH
            )
        # ``custom_objects`` maps the serialized class name to the class Keras
        # should use when rebuilding that layer; the layer's weights come from
        # the .h5 file itself. Passing the class (as the multilingual branch
        # already did) avoids downloading and instantiating a separate
        # pretrained BERT whose weights would never be used.
        self.model = load_model(model_path, custom_objects={'TFBertModel': TFBertModel})
        self.load_preprocess(preprocessor)

    def load_preprocess(self, prep):
        """Attach (or replace) the preprocessor used by ``predict``/``decode``."""
        self.preprocessor = prep

    def _require_preprocessor(self):
        """Return the attached preprocessor, or raise if none has been set."""
        if self.preprocessor is None:
            raise RuntimeError(
                "No preprocessor loaded; pass one to the constructor or call load_preprocess()"
            )
        return self.preprocessor

    def predict(self, text):
        """
        Preprocess ``text`` and run it through the model.

        Returns:
            Whatever ``self.model.predict`` returns for the preprocessed input
            (the raw multi-head predictions; pass them to ``decode``).

        Raises:
            RuntimeError: If no preprocessor has been attached.
        """
        preprocessed = self._require_preprocessor().preprocess_text(text)
        return self.model.predict(preprocessed)

    def decode(self, pred):
        """
        Turn raw predictions into human-readable labels.

        ``pred[0]`` is decoded as emotion and ``pred[1]`` as toxicity.

        Returns:
            A ``(toxicity, emotion)`` tuple, in that order.

        Raises:
            RuntimeError: If no preprocessor has been attached.
        """
        decoder = self._require_preprocessor().decoder
        return decoder.toxicity(pred[1]), decoder.emotion(pred[0])