|
import tensorflow as tf
|
|
|
|
from tensorflow.keras.models import load_model
|
|
from transformers import TFBertModel
|
|
|
|
class MultiTaskModel:
    """
    A class used to represent a MultiTaskModel, which classifies the emotion and toxicity of Valorant chat messages.

    The underlying Keras model is loaded from a saved ``.h5`` file that contains a
    Hugging Face ``TFBertModel`` layer. Raw text is converted to model input by a
    user-supplied preprocessor (``preprocess_text``). Raw predictions are turned
    into labels by that preprocessor's ``decoder`` (``toxicity`` / ``emotion``).
    """

    # Default saved-model file for each variant, keyed by ``is_multilingual``.
    _MODEL_FILES = {
        True: 'model_with_bert_multilingual.h5',
        False: 'model_with_bert_base.h5',
    }

    def __init__(self, is_multilingual=False, preprocessor=None, model_path=None):
        """
        Load the saved multi-task model and attach a preprocessor.

        :param is_multilingual: if truthy, load the model built on multilingual BERT
            (``model_with_bert_multilingual.h5``); otherwise load the one built on
            ``bert-base-uncased`` (``model_with_bert_base.h5``).
        :param preprocessor: object exposing ``preprocess_text(text)`` and a
            ``decoder`` with ``toxicity(...)`` / ``emotion(...)`` methods. It may be
            ``None`` here and supplied later via :meth:`load_preprocess`.
        :param model_path: optional path that overrides the default ``.h5`` file
            chosen from ``is_multilingual``.
        """
        if model_path is None:
            # bool() keeps the original truthiness semantics of ``if is_multilingual:``.
            model_path = self._MODEL_FILES[bool(is_multilingual)]
        # Keras expects ``custom_objects`` to map the serialized layer name to the
        # *class* used to rebuild it. The layer's weights are restored from the .h5
        # file, so there is no need to download a pretrained checkpoint first.
        self.model = load_model(model_path, custom_objects={'TFBertModel': TFBertModel})
        self.load_preprocess(preprocessor)

    def load_preprocess(self, prep):
        """
        Attach (or replace) the preprocessor used by :meth:`predict` and :meth:`decode`.

        :param prep: preprocessor object, or ``None`` to leave the model without one.
        """
        self.preprocessor = prep

    def _require_preprocessor(self):
        """Return the attached preprocessor, raising a clear error if none is set."""
        prep = getattr(self, 'preprocessor', None)
        if prep is None:
            raise RuntimeError(
                'MultiTaskModel has no preprocessor; pass one to the constructor '
                'or call load_preprocess() first.'
            )
        return prep

    def predict(self, text):
        """
        Preprocess ``text`` and run it through the loaded model.

        :param text: input passed unchanged to ``preprocessor.preprocess_text``.
        :returns: whatever ``self.model.predict`` returns for the preprocessed input.
            This is presumably one output per task (emotion first, toxicity second,
            matching :meth:`decode`); confirm against the saved model.
        :raises RuntimeError: if no preprocessor has been attached.
        """
        prepared = self._require_preprocessor().preprocess_text(text)
        return self.model.predict(prepared)

    def decode(self, pred):
        """
        Convert raw model output into labels.

        :param pred: indexable model output; ``pred[0]`` is decoded as emotion and
            ``pred[1]`` as toxicity.
        :returns: a ``(toxicity, emotion)`` tuple, in that order.
        :raises RuntimeError: if no preprocessor has been attached.
        """
        decoder = self._require_preprocessor().decoder
        return decoder.toxicity(pred[1]), decoder.emotion(pred[0])