from transformers import AutoTokenizer, EncoderDecoderModel
from transformers import pipeline as hf_pipeline
from pathlib import Path
import re
class NpcBertGPT2:
    """Wraps the fine-tuned BERT-to-GPT2 EncoderDecoderModel used for report conclusion generation."""
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        # Path to the fine-tuned weights, relative to app.py
        self.pretrained_model = "./models/npc-bert-gpt2-best"
    def load(self):
        """Loads the fine-tuned EncoderDecoder model and related components.

        This method initializes the model, tokenizer, and pipeline for the
        report conclusion generation task using the pre-trained weights from
        the specified directory.

        Raises:
            FileNotFoundError: If the pretrained model directory is not found.
        """
        if not Path(self.pretrained_model).is_dir():
            raise FileNotFoundError(f"Cannot find pretrained model at: {self.pretrained_model}")

        self.model = EncoderDecoderModel.from_pretrained(self.pretrained_model)
        self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model)
        # Beam search combined with top-k sampling; 5-gram repetition is
        # blocked and generation stops early to curb degenerate repeated output.
        self.pipeline = hf_pipeline("text2text-generation",
                                    model=self.model,
                                    tokenizer=self.tokenizer,
                                    device='cpu',
                                    num_beams=4,
                                    do_sample=True,
                                    top_k=5,
                                    temperature=0.95,
                                    early_stopping=True,
                                    no_repeat_ngram_size=5,
                                    max_new_tokens=60)
    def __call__(self, *args):
        """Generates a report conclusion from the input report text.

        This method should be called only after the `load` method has been
        executed to ensure that the model and pipeline are properly
        initialized. It forwards its arguments to the Hugging Face
        text2text-generation pipeline.

        Args:
            *args: Variable length argument list to pass to the pipeline.

        Returns:
            str: The generated conclusion text.

        Raises:
            BrokenPipeError: If the model has not been loaded before calling this method.
        """
        if self.pipeline is None:
            msg = "Model was not initialized, have you run load()?"
            raise BrokenPipeError(msg)

        pipe_out, = self.pipeline(*args)
        pipe_out = pipe_out['generated_text']
        # Hard-coded cleanup: truncate at known spurious continuations that the
        # model tends to append after the real conclusion.
        mo = re.search(r"\. (questionable|anterio|zius)", pipe_out)
        if mo is not None:
            end_sig = mo.start()
            pipe_out = pipe_out[:end_sig + 1]
        return pipe_out
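

# Minimal usage sketch, assuming the fine-tuned weights have already been
# downloaded to ./models/npc-bert-gpt2-best (the path __init__ expects).
# The report text below is a made-up placeholder, not real clinical input.
if __name__ == "__main__":
    generator = NpcBertGPT2()
    generator.load()
    # Any free-text report body can be passed; the call returns the generated
    # conclusion string after the regex-based truncation above.
    conclusion = generator("Example findings: mucosal thickening of the nasopharynx is noted.")
    print(conclusion)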