from transformers import AutoTokenizer, EncoderDecoderModel
from transformers import pipeline as hf_pipeline
from pathlib import Path
import re

from .app_logger import get_logger


class NpcBertGPT2():
    """Report-conclusion generator backed by a fine-tuned BERT→GPT-2 encoder-decoder.

    Wraps a Hugging Face ``EncoderDecoderModel`` in a ``text2text-generation``
    pipeline. Call :meth:`load` once before using the instance as a callable.
    """

    # Class-level logger shared by all instances.
    logger = get_logger()

    def __init__(self):
        # Model components are populated by load(); None until then.
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        # Path is relative to app.py (the process working directory at launch).
        self.pretrained_model = "./models/npc-bert-gpt2-best"
        self.logger.info(f"Created {type(self).__name__} instance.")

    def load(self):
        """Loads the fine-tuned EncoderDecoder model and related components.

        This method initializes the model, tokenizer, and pipeline for the
        report conclusion generation task using the pre-trained weights from
        the specified directory.

        Raises:
            FileNotFoundError: If the pretrained model directory is not found.
        """
        if not Path(self.pretrained_model).is_dir():
            raise FileNotFoundError(f"Cannot find pretrained model at: {self.pretrained_model}")

        self.model = EncoderDecoderModel.from_pretrained(self.pretrained_model)
        self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model)
        # Beam search combined with constrained sampling; generation is capped
        # at 60 new tokens and 5-gram repetition is forbidden.
        self.pipeline = hf_pipeline(
            "text2text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device='cpu',
            num_beams=4,
            do_sample=True,
            top_k=5,
            temperature=.95,
            early_stopping=True,
            no_repeat_ngram_size=5,
            max_new_tokens=60,
        )

    def __call__(self, *args):
        """Generates a report conclusion from the given input text.

        This method should be called only after the `load` method has been
        executed to ensure that the model and pipeline are properly
        initialized. The arguments are forwarded verbatim to the Hugging Face
        text2text-generation pipeline.

        Args:
            *args: Variable length argument list to pass to the pipeline
                (typically a single input string).

        Returns:
            str: The generated conclusion text, truncated at known
            degenerate-repetition markers.

        Raises:
            BrokenPipeError: If the model has not been loaded before calling
                this method.
        """
        if self.pipeline is None:
            msg = "Model was not initialized, have you run load()?"
            raise BrokenPipeError(msg)

        self.logger.info(f"Called with arguments {args = }")
        # The pipeline returns a one-element list; unpack it deliberately so
        # an unexpected multi-result output raises instead of being ignored.
        pipe_out, = self.pipeline(*args)
        pipe_out = pipe_out['generated_text']
        self.logger.info(f"Generated text: {pipe_out}")

        # Hard-coded cleanup: these tokens were observed to mark the onset of
        # degenerate repeated content, so everything from the first match
        # onward is dropped (keeping the terminating period).
        mo = re.search(r"\. (questionable|anterio|zius)", pipe_out)
        if mo is not None:
            end_sig = mo.start()
            pipe_out = pipe_out[:end_sig + 1]

        self.logger.info(f"Displayed text: {pipe_out}")
        return pipe_out