cooldragon12's picture
upload app file
acf980a verified
raw
history blame
No virus
1.34 kB
from transformers import BertTokenizer
class Decoder:
    """Map model predictions back to their original toxicity/emotion labels.

    Two previously fitted encoder objects are unpickled at construction time.
    Each one only needs to expose an ``inverse_transform(pred)`` method; the
    concrete encoder type is not visible from this file.
    """

    # Default pickle locations. Relative paths are resolved against the
    # current working directory of the process, not against this module.
    DEFAULT_TOXICITY_PATH = 'pipeline/preprocessing/encoder_toxicity.pkl'
    DEFAULT_EMOTION_PATH = 'pipeline/preprocessing/encoder_emotion.pkl'

    def __init__(self, toxicity_path=DEFAULT_TOXICITY_PATH,
                 emotion_path=DEFAULT_EMOTION_PATH):
        """Load both label encoders from disk.

        Args:
            toxicity_path: Path of the pickled toxicity encoder.
            emotion_path: Path of the pickled emotion encoder.

        Raises:
            OSError: If either file cannot be opened.
            pickle.UnpicklingError: If either file is not a valid pickle.
        """
        self.__encoder_toxicity = self._load_encoder(toxicity_path)
        self.__encoder_emotion = self._load_encoder(emotion_path)

    @staticmethod
    def _load_encoder(path):
        """Unpickle and return the object stored at *path*."""
        import pickle

        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # Only point this at encoder files that ship with the application.
        with open(path, 'rb') as f:
            return pickle.load(f)

    # Decoding encoded labels (the original note says they are one-hot
    # encoded -- presumably true, confirm against the training pipeline).
    def toxicity(self, pred):
        """Return the toxicity labels for *pred* via the encoder's inverse_transform."""
        return self.__encoder_toxicity.inverse_transform(pred)

    def emotion(self, pred):
        """Return the emotion labels for *pred* via the encoder's inverse_transform."""
        return self.__encoder_emotion.inverse_transform(pred)
class Preprocessor:
    """Preprocess raw text into token ids for the model.

    Wraps a pretrained BERT tokenizer and also carries a ``Decoder`` so that
    callers can turn model predictions back into labels.
    """

    # Pretrained tokenizer checkpoints selected by ``is_multilingual``.
    MULTILINGUAL_MODEL = 'bert-base-multilingual-cased'
    DEFAULT_MODEL = 'bert-base-uncased'

    def __init__(self, is_multilingual: bool = False):
        """Create the tokenizer and the label decoder.

        Args:
            is_multilingual: When true, load the multilingual cased BERT
                tokenizer; otherwise load the uncased English one.
        """
        model_name = self.MULTILINGUAL_MODEL if is_multilingual else self.DEFAULT_MODEL
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        # Decoder object used to decode the encoded labels predicted by the model.
        self.decoder = Decoder()

    def preprocess_text(self, text, max_length=65):
        """Tokenize *text* into a fixed-length sequence of token ids.

        Args:
            text: The text handed to ``tokenizer.encode``.
            max_length: Length every sequence is padded or truncated to.
                Defaults to 65 -- presumably the input length the model was
                trained with; confirm before changing it.

        Returns:
            Whatever ``tokenizer.encode`` returns for ``return_tensors='tf'``
            (expected to be a TensorFlow tensor of token ids).
        """
        return self.tokenizer.encode(
            text,
            add_special_tokens=True,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=False,
            return_tensors='tf',
        )