mlwong's picture
Add logger
d718096
raw
history blame contribute delete
No virus
3.29 kB
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline as hf_pipeline
from pathlib import Path
from typing import Any, Dict
from .app_logger import get_logger
class NpcBertCLS:
    r"""A class for performing report classification with BERT.

    This class facilitates report classification tasks using a BERT model
    fine-tuned on NPC staging reports. The base model is an uncased model
    released by Microsoft, available on the Hugging Face model hub as
    'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'.

    Attributes:
        model (transformers.PreTrainedModel):
            The fine-tuned BERT model for sequence classification.
        tokenizer (transformers.PreTrainedTokenizer):
            The tokenizer for the BERT model.
        pipeline (transformers.Pipeline):
            The Hugging Face text-classification pipeline.
        pretrained_model (str):
            The path to the directory containing the fine-tuned model.
    """
    # Shared class-level logger; all instances log through the same handler.
    logger = get_logger()

    # Minimum number of characters required before classification is attempted.
    MIN_TEXT_LENGTH = 10

    def __init__(self) -> None:
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        # Path is relative to app.py (the application entry point), not this module.
        self.pretrained_model = "./models/npc-bert-cls"
        self.logger.info(f"Created {__class__.__name__} instance.")

    def load(self) -> None:
        """Loads the fine-tuned BERT model and related components.

        Initializes the model, tokenizer, and text-classification pipeline
        using the pre-trained weights from the directory given by
        ``self.pretrained_model``. Must be called before :meth:`__call__`.

        Raises:
            FileNotFoundError: If the pretrained model directory is not found.
        """
        if not Path(self.pretrained_model).is_dir():
            # NOTE: grammar of the original message fixed ("Cannot found" -> "Cannot find").
            raise FileNotFoundError(f"Cannot find pretrained model at: {self.pretrained_model}")

        self.model = AutoModelForSequenceClassification.from_pretrained(self.pretrained_model)
        self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model)
        # device='cpu' requires a transformers version that accepts string devices
        # (older releases expect an int, -1 for CPU) -- TODO confirm pinned version.
        self.pipeline = hf_pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            device='cpu',
        )

    def __call__(self, *args: Any) -> Any:
        """Performs classification on the given reports.

        Must be called only after :meth:`load` has been executed so that the
        pipeline is initialized. Arguments are forwarded verbatim to the
        Hugging Face text-classification pipeline.

        Args:
            *args: Variable length argument list passed to the pipeline; the
                first argument is expected to be the report text (or texts).

        Returns:
            Dict[str, float]: Mapping of class label to confidence score, or a
            plain string message when the input text is too short to classify.

        Raises:
            BrokenPipeError: If the model has not been loaded before calling
                this method.
        """
        self.logger.info(f"Called with {args = }")
        if self.pipeline is None:
            msg = "Model was not initialized, have you run load()?"
            raise BrokenPipeError(msg)

        # Guard against a zero-argument call (the original code raised an
        # uncaught IndexError on args[0]) and against inputs too short to
        # yield a meaningful classification.
        if not args or len(args[0]) < self.MIN_TEXT_LENGTH:
            return "Not enough text for classification!"

        pipe_out = self.pipeline(*args)
        # Flatten the pipeline's list of {'label': ..., 'score': ...} dicts
        # into a single {label: score} mapping.
        return {o['label']: o['score'] for o in pipe_out}