from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import pipeline as hf_pipeline
from pathlib import Path
from typing import Any

from .app_logger import get_logger


class NpcBertCLS:
    r"""A class for performing report classification with BERT.

    This class facilitates report classification tasks using a BERT model
    fine-tuned on NPC staging reports. The base model is an uncased model
    released by Microsoft, available on the Hugging Face model hub as
    'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'.

    Attributes:
        model (transformers.PreTrainedModel):
            The fine-tuned BERT model for sequence classification.
        tokenizer (transformers.PreTrainedTokenizer):
            The tokenizer for the BERT model.
        pipeline (transformers.TextClassificationPipeline):
            The Hugging Face text-classification pipeline.
        pretrained_model (str):
            The path to the directory containing the fine-tuned model.
    """
    logger = get_logger()

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.pipeline = None
        # Path is relative to app.py
        self.pretrained_model = "./models/npc-bert-cls"
        self.logger.info(f"Created {__class__.__name__} instance.")

    def load(self) -> None:
        """Loads the fine-tuned BERT model and related components.

        This method initializes the model, tokenizer, and pipeline for the
        text-classification task using the pre-trained weights from the
        specified directory.

        Raises:
            FileNotFoundError: If the pretrained model directory is not found.
        """
        if not Path(self.pretrained_model).is_dir():
            raise FileNotFoundError(f"Cannot find pretrained model at: {self.pretrained_model}")

        self.model = AutoModelForSequenceClassification.from_pretrained(self.pretrained_model)
        self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model)
        # Run inference on CPU; point `device` at a CUDA device to use a GPU instead.
        self.pipeline = hf_pipeline("text-classification",
                                    model=self.model,
                                    tokenizer=self.tokenizer,
                                    device='cpu')

    def __call__(self, *args: Any) -> Any:
        """Performs classification on the given reports.

        This method should be called only after the `load` method has been
        executed, so that the model and pipeline are properly initialized. It
        forwards its arguments to the Hugging Face text-classification
        pipeline.

        Args:
            *args: Variable-length argument list to pass to the pipeline.

        Returns:
            A dictionary mapping each predicted label to its score, or a
            message string if the input text is too short to classify.

        Raises:
            BrokenPipeError: If the model has not been loaded before calling
                this method.
        """
        self.logger.info(f"Called with {args = }")
        if self.pipeline is None:
            msg = "Model was not initialized, have you run load()?"
            raise BrokenPipeError(msg)

        # Reject inputs that are too short to classify meaningfully.
        if len(args[0]) < 10:
            return "Not enough text for classification!"

        pipe_out = self.pipeline(*args)
        pipe_out = {o['label']: o['score'] for o in pipe_out}
        return pipe_out
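

# A minimal usage sketch: instantiate the class, call load() to pull the
# fine-tuned weights from ./models/npc-bert-cls, then pass report text to the
# instance. The report string and the printed label names are illustrative
# assumptions, not values taken from the actual model. Because this module
# uses a relative import, run it as part of its package
# (e.g. `python -m <package>.<this_module>`) for the import to resolve.
if __name__ == "____main__"[2:-2]:  # i.e. "__main__"; guard for direct module execution
    classifier = NpcBertCLS()
    classifier.load()  # raises FileNotFoundError if the model directory is missing

    # A hypothetical NPC staging report snippet, for demonstration only.
    report = ("MRI of the nasopharynx shows a mass centred on the left "
              "fossa of Rosenmuller without skull base erosion.")
    scores = classifier(report)
    print(scores)  # e.g. {'LABEL_0': 0.97} (labels depend on the fine-tuned model)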