# Source: npc-bert-demo/npc_bert_models/summary_module.py
# Last commit: 91c6989 by mlwong ("Fix typo"); original file size 3.11 kB.
from transformers import AutoTokenizer, EncoderDecoderModel
from transformers import pipeline as hf_pipeline
from pathlib import Path
import re
from .app_logger import get_logger
class NpcBertGPT2():
    """Report-conclusion generator backed by a fine-tuned EncoderDecoder model.

    Usage: construct an instance, call :meth:`load` once to read the model
    from disk, then call the instance with the input text to obtain the
    generated (and post-processed) conclusion.
    """

    logger = get_logger()

    # Generated text is cut just after the first period that is followed by
    # one of these words; they were hard-coded as markers of repeated /
    # degenerate output. Written as a raw string so ``\.`` is a regex escape
    # rather than an invalid string escape (SyntaxWarning on Python 3.12+),
    # and compiled once instead of on every call.
    _REPEAT_MARKER = re.compile(r"\. (questionable|anterio|zius)")

    def __init__(self, pretrained_model="./models/npc-bert-gpt2-best"):
        """Create an instance whose model is not yet loaded.

        Args:
            pretrained_model: Directory containing the fine-tuned model and
                tokenizer files. The default is a relative path meant to be
                relative to app.py; like any relative path it is resolved
                against the process's current working directory when
                :meth:`load` runs.
        """
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        self.pretrained_model = pretrained_model
        self.logger.info(f"Created {__class__.__name__} instance.")

    def load(self):
        """Load the fine-tuned EncoderDecoder model, tokenizer and pipeline.

        Builds a Hugging Face ``text2text-generation`` pipeline on CPU from
        the files in ``self.pretrained_model``. The following generation
        settings are passed to the pipeline: 4 beams with sampling enabled,
        top-k 5, temperature 0.95, early stopping, no repeated 5-grams, and
        at most 60 new tokens.

        Raises:
            FileNotFoundError: If ``self.pretrained_model`` is not an existing
                directory.
        """
        if not Path(self.pretrained_model).is_dir():
            raise FileNotFoundError(
                f"Cannot find pretrained model at: {self.pretrained_model}")
        self.model = EncoderDecoderModel.from_pretrained(self.pretrained_model)
        self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model)
        self.pipeline = hf_pipeline(
            "text2text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device='cpu',
            num_beams=4,
            do_sample=True,
            top_k=5,
            temperature=.95,
            early_stopping=True,
            no_repeat_ngram_size=5,
            max_new_tokens=60,
        )

    def __call__(self, *args):
        """Generate a report conclusion for the given input.

        Must be called after :meth:`load`. The arguments are forwarded
        unchanged to the ``text2text-generation`` pipeline, which must yield
        exactly one result (e.g. when given a single input string). The
        generated text is then truncated right after the first period that
        is followed by one of the hard-coded repeat markers (the period is
        kept).

        Args:
            *args: Positional arguments passed straight to the pipeline.

        Returns:
            The generated text, possibly truncated.

        Raises:
            BrokenPipeError: If :meth:`load` has not been called yet. (Kept
                for backward compatibility with existing callers.)
            ValueError: If the pipeline returns more or fewer than one result.
        """
        if self.pipeline is None:
            msg = "Model was not initialized, have you run load()?"
            raise BrokenPipeError(msg)
        self.logger.info(f"Called with arguments {args = }")

        # The pipeline returns a sequence of dicts; tuple-unpacking enforces
        # that exactly one result came back.
        (result,) = self.pipeline(*args)
        text = result['generated_text']
        self.logger.info(f"Generated text: {text}")

        # Remove repeated lines by hard-coded cut points, keeping the period.
        marker = self._REPEAT_MARKER.search(text)
        if marker is not None:
            text = text[:marker.start() + 1]
        self.logger.info(f"Displayed text: {text}")
        return text