cooldragon12's picture
upload app file
acf980a verified
raw
history blame
No virus
1.16 kB
import tensorflow as tf
from tensorflow.keras.models import load_model
from transformers import TFBertModel
class MultiTaskModel:
    """
    A class used to represent a MultiTaskModel, which classifies the emotion and toxicity of Valorant chat messages.

    Wraps a saved Keras model (containing a ``TFBertModel`` layer) together
    with a caller-supplied preprocessor that converts raw text into model
    input and decodes the model's outputs.
    """

    # Saved Keras model files, resolved relative to the current working
    # directory. Override on a subclass to load them from elsewhere.
    BASE_MODEL_PATH = "model_with_bert_base.h5"
    MULTILINGUAL_MODEL_PATH = "model_with_bert_multilingual.h5"

    def __init__(self, is_multilingual=False, preprocessor=None):
        """Load the saved model and attach the preprocessor.

        Args:
            is_multilingual: if true, load the multilingual-BERT model file
                instead of the base-BERT one.
            preprocessor: object providing ``preprocess_text(text)`` and a
                ``decoder`` attribute with ``toxicity(...)`` and
                ``emotion(...)`` methods. It may also be attached later with
                ``load_preprocess``; ``predict`` and ``decode`` raise
                ``RuntimeError`` until one is set.
        """
        model_path = (
            self.MULTILINGUAL_MODEL_PATH if is_multilingual else self.BASE_MODEL_PATH
        )
        # custom_objects maps the serialized layer name to the class Keras
        # uses to rebuild that layer; the layer's weights are restored from
        # the .h5 file itself, so no separate from_pretrained() download is
        # needed.
        self.model = load_model(
            model_path, custom_objects={"TFBertModel": TFBertModel}
        )
        self.load_preprocess(preprocessor)

    def load_preprocess(self, prep):
        """Attach (or replace) the preprocessor used by predict/decode."""
        self.preprocessor = prep

    def _require_preprocessor(self):
        """Return the attached preprocessor, or raise if none has been set."""
        if self.preprocessor is None:
            raise RuntimeError(
                "MultiTaskModel has no preprocessor; pass one to the "
                "constructor or call load_preprocess() first."
            )
        return self.preprocessor

    def predict(self, text):
        """Preprocess ``text`` and run it through the model.

        Returns whatever ``self.model.predict`` returns for the preprocessed
        input; ``decode`` treats index 0 as the emotion output and index 1 as
        the toxicity output.

        Raises:
            RuntimeError: if no preprocessor has been attached.
        """
        preptext = self._require_preprocessor().preprocess_text(text)
        return self.model.predict(preptext)

    def decode(self, pred):
        """Decode a prediction into a ``(toxicity, emotion)`` tuple.

        ``pred[1]`` is decoded as toxicity and ``pred[0]`` as emotion; note
        that the returned tuple lists toxicity first.

        Raises:
            RuntimeError: if no preprocessor has been attached.
        """
        decoder = self._require_preprocessor().decoder
        return decoder.toxicity(pred[1]), decoder.emotion(pred[0])