import logging
import pprint as pp
import time

import torch
from transformers import GenerationConfig, Pipeline, pipeline

from utils import compare_model_size

# Setting up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class BatchAggregator:
    CONFIGURED_MODELS = [
        "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
    ]  # TODO: Add models here
    DEFAULT_INSTRUCTION = (
        "Write a comprehensive yet concise summary that pulls together the "
        "main points of the following text:"
    )
    GENERIC_CONFIG = GenerationConfig(
        num_beams=8,
        early_stopping=True,
        do_sample=False,
        min_new_tokens=32,
        max_new_tokens=256,
        repetition_penalty=1.1,
        length_penalty=1.4,
        no_repeat_ngram_size=4,
        encoder_no_repeat_ngram_size=5,
    )

    def __init__(
        self,
        model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1",
        **kwargs,
    ):
        self.device = None
        self.is_compiled = False
        self.logger = logging.getLogger(__name__)
        self.init_model(model_name)

    def init_model(self, model_name: str) -> None:
        """
        Initialize the model.

        :param model_name: The name of the model to use.
        """
        # Free up GPU memory before (re)loading a model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.logger.info(f"Setting model to {model_name}")
        self.model_name = model_name
        self.aggregator = self._create_pipeline(model_name)
        self._configure_model()

        # Update the generation config with tokenizer-specific token IDs:
        # T5-style models use fixed IDs (pad=0, eos=1); other models fall back
        # to the IDs reported by their tokenizer.
        is_t5 = "t5" in model_name.lower()
        tokenizer_params = {
            "decoder_start_token_id": 0
            if is_t5
            else self.aggregator.tokenizer.eos_token_id,
            "eos_token_id": 1 if is_t5 else self.aggregator.tokenizer.eos_token_id,
            "pad_token_id": 0 if is_t5 else self.aggregator.tokenizer.pad_token_id,
        }
        self.update_generation_config(**tokenizer_params)

    def _create_pipeline(
        self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
    ) -> Pipeline:
        """
        Create a text2text-generation pipeline for the model.

        :param str model_name: model name to use,
            default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
        :return Pipeline: the pipeline for the model
        :raises Exception: if the pipeline cannot be created
        """
        self.device = 0 if torch.cuda.is_available() else -1
        try:
            self.logger.info(
                f"Creating pipeline with model {model_name} on device {self.device}"
            )
            return pipeline(
                "text2text-generation",
                model=model_name,
                device=self.device,
                torch_dtype=torch.float32,
            )
        except Exception as e:
            self.logger.error(f"Failed to create pipeline: {e}")
            raise

    def _configure_model(self):
        """
        Configure the model for generation.
        """
        # torch.compile (PyTorch 2.0+) is a best-effort optimization; fall back
        # to the uncompiled model if it is unavailable or fails.
        try:
            self.aggregator.model = torch.compile(self.aggregator.model)
            self.is_compiled = True
        except Exception as e:
            self.logger.warning(f"Could not compile model with Torch 2.0: {e}")

        if self.model_name not in self.CONFIGURED_MODELS:
            self.logger.info("Setting generation config to general defaults")
            self._set_default_generation_config()
        else:
            try:
                self.logger.info("Loading generation config from hub")
                self.aggregator.model.generation_config = (
                    GenerationConfig.from_pretrained(self.model_name)
                )
            except Exception as e:
                self.logger.warning(
                    f"Could not load generation config, using defaults: {e}"
                )
                self._set_default_generation_config()
        self.logger.info(self.aggregator.model.generation_config.to_json_string())

    def _set_default_generation_config(self):
        """
        Set the default generation configuration for the model.
""" self.aggregator.model.generation_config = self.GENERIC_CONFIG if "bart" in self.model_name.lower(): self.logger.info("Using BART model, updating generation config") upd = { "num_beams": 8, "repetition_penalty": 1.3, "length_penalty": 1.0, "_from_model_config": False, "max_new_tokens": 256, "min_new_tokens": 32, "no_repeat_ngram_size": 3, "encoder_no_repeat_ngram_size": 6, } # TODO: clean up self.aggregator.model.generation_config.update(**upd) if ( "large" or "xl" in self.model_name.lower() or compare_model_size(self.model_name, 500) ): upd = {"num_beams": 4} self.update_generation_config(**upd) def update_generation_config(self, **kwargs): """ Update the generation configuration with the specified parameters. Args: **kwargs: The parameters to update in the generation configuration. """ self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}") self.aggregator.model.generation_config.update(**kwargs) def update_loglevel(self, level: str = "INFO"): """ Update the log level. Args: level (str): The log level to set. Defaults to "INFO". """ self.logger.setLevel(level) def infer_aggregate( self, text_list: list, instruction: str = DEFAULT_INSTRUCTION, **kwargs, ) -> str: f""" Generate a summary of the specified texts. Args: text_list (list): The texts to summarize. instruction (str): The instruction for the summary. Defaults to {self.DEFAULT_INSTRUCTION}. **kwargs: Additional parameters to update in the generation configuration. Returns: The generated summary. """ joined_text = "\n".join(text_list) prompt = f"{instruction}\n\n{joined_text}\n" if kwargs: self.update_generation_config(**kwargs) st = time.perf_counter() self.logger.info(f"inference on {len(text_list)} texts ...") result = self.aggregator( prompt, generation_config=self.aggregator.model.generation_config, )[0]["generated_text"] self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s") self.logger.info( f"Input tokens:\t{self.count_tokens(prompt)}. Output tokens:\t{self.count_tokens(result)}" ) return result def count_tokens(self, text: str) -> int: """count the number of tokens in a text""" return ( len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False)) if text else 0 )