TextManipulation / textSFunctionality.py
sahibnanda's picture
commit
d9579cb
raw
history blame
893 Bytes
import re
import os
import tensorflow as tf
import keras
import keras_nlp
# Sequence-length limits handed to the preprocessor when it is built.
MAX_ENCODER_SEQUENCE_LENGTH = 512
MAX_DECODER_SEQUENCE_LENGTH = 128

# Directory used as the preset location; the weights file lives inside it.
MODEL_PATH = "TextSummarizationModel"
WEIGHT_PATH = os.path.join(MODEL_PATH, "new_model.weights.h5")
def cleanText(text):
    """Normalize *text* before it is passed to the model.

    The input is coerced with ``str()``, every character that is not an
    ASCII letter, an ASCII digit, or whitespace is removed, and the result
    is lower-cased.

    Args:
        text: Value to clean; non-string values are converted with ``str()``.

    Returns:
        The cleaned, lower-cased string.
    """
    text = str(text)
    # Removed characters are dropped, not replaced by a space, so words
    # separated only by punctuation end up joined ("a,b" -> "ab").
    text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
    return text.lower()
# Build the preprocessor from the preset at MODEL_PATH, using the
# module-level encoder/decoder sequence-length limits.
preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
    MODEL_PATH,
    encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
    decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)

# Create the seq2seq model from the same preset, attached to that preprocessor.
model = keras_nlp.models.BartSeq2SeqLM.from_preset(
    MODEL_PATH,
    preprocessor=preprocessor,
)

# Load the weights file at WEIGHT_PATH into the model.
model.load_weights(WEIGHT_PATH)
def generateText(input_text, model=model, max_length=200):
    """Clean *input_text* and pass it to ``model.generate``.

    Args:
        input_text: Text to generate from; it is normalized with
            ``cleanText`` before being handed to the model.
        model: Object exposing ``generate(text, max_length=...)``. Defaults
            to the module-level ``model`` as bound when this function is
            defined.
        max_length: Forwarded unchanged as the ``max_length`` keyword of
            ``model.generate``.

    Returns:
        Whatever ``model.generate`` returns (presumably the generated
        text; confirm against the keras_nlp documentation).
    """
    cleaned = cleanText(input_text)
    return model.generate(cleaned, max_length=max_length)