# document-summarization/aggregate.py
import pprint as pp
import logging
import time
import torch
from transformers import GenerationConfig, Pipeline, pipeline
from utils import compare_model_size
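# NOTE: compare_model_size is a repo-local helper from utils.py; judging by its
# usage below, it presumably returns True when the named model exceeds the given
# parameter count (in millions).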
# Setting up logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
class BatchAggregator:
CONFIGURED_MODELS = [
"pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
] # TODO: Add models here
DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
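    # Default decoding settings: deterministic beam search with length and
    # repetition controls. Model-specific overrides are applied in
    # _set_default_generation_config.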
GENERIC_CONFIG = GenerationConfig(
num_beams=8,
early_stopping=True,
do_sample=False,
min_new_tokens=32,
max_new_tokens=256,
repetition_penalty=1.1,
length_penalty=1.4,
no_repeat_ngram_size=4,
encoder_no_repeat_ngram_size=5,
)
    def __init__(
        self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
    ):
        """
        Initialize the aggregator and load the given model.

        :param model_name: The name of the model to use.
        """
        self.device = None
        self.is_compiled = False
        self.logger = logging.getLogger(__name__)
        self.init_model(model_name)
def init_model(self, model_name: str) -> None:
"""
Initialize the model.
:param model_name: The name of the model to use.
"""
# Free up memory
if torch.cuda.is_available():
torch.cuda.empty_cache()
self.logger.info(f"Setting model to {model_name}")
self.model_name = model_name
self.aggregator = self._create_pipeline(model_name)
self._configure_model()
        # align the generation config's token ids with the tokenizer
        # (T5-style models use fixed ids: decoder_start=0, eos=1, pad=0)
        is_t5 = "t5" in model_name.lower()
        tokenizer_params = {
            "decoder_start_token_id": 0
            if is_t5
            else self.aggregator.tokenizer.eos_token_id,
            "eos_token_id": 1 if is_t5 else self.aggregator.tokenizer.eos_token_id,
            "pad_token_id": 0 if is_t5 else self.aggregator.tokenizer.pad_token_id,
        }
self.update_generation_config(**tokenizer_params)
    def _create_pipeline(
        self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
    ) -> Pipeline:
        """
        Create a text2text-generation pipeline for the model.

        :param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
        :return Pipeline: the pipeline wrapping the model
        :raises Exception: if the pipeline cannot be created
        """
self.device = 0 if torch.cuda.is_available() else -1
try:
self.logger.info(
f"Creating pipeline with model {model_name} on device {self.device}"
)
return pipeline(
"text2text-generation",
model_name,
device=self.device,
torch_dtype=torch.float32,
)
except Exception as e:
self.logger.error(f"Failed to create pipeline: {e}")
raise
def _configure_model(self):
"""
Configure the model for generation.
"""
try:
self.aggregator.model = torch.compile(self.aggregator.model)
self.is_compiled = True
except Exception as e:
self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
if self.model_name not in self.CONFIGURED_MODELS:
self.logger.info("Setting generation config to general defaults")
self._set_default_generation_config()
else:
try:
self.logger.info("Loading generation config from hub")
self.aggregator.model.generation_config = (
GenerationConfig.from_pretrained(self.model_name)
)
except Exception as e:
self.logger.warning(
f"Could not load generation config, using defaults: {e}"
)
self._set_default_generation_config()
self.logger.info(self.aggregator.model.generation_config.to_json_string())
def _set_default_generation_config(self):
"""
Set the default generation configuration for the model.
"""
self.aggregator.model.generation_config = self.GENERIC_CONFIG
if "bart" in self.model_name.lower():
self.logger.info("Using BART model, updating generation config")
upd = {
"num_beams": 8,
"repetition_penalty": 1.3,
"length_penalty": 1.0,
"_from_model_config": False,
"max_new_tokens": 256,
"min_new_tokens": 32,
"no_repeat_ngram_size": 3,
"encoder_no_repeat_ngram_size": 6,
} # TODO: clean up
self.aggregator.model.generation_config.update(**upd)
        if (
            "large" in self.model_name.lower()
            or "xl" in self.model_name.lower()
            or compare_model_size(self.model_name, 500)
        ):
            # larger checkpoints get fewer beams to keep memory and latency in check
            upd = {"num_beams": 4}
            self.update_generation_config(**upd)
def update_generation_config(self, **kwargs):
"""
Update the generation configuration with the specified parameters.
Args:
**kwargs: The parameters to update in the generation configuration.
"""
self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}")
self.aggregator.model.generation_config.update(**kwargs)
def update_loglevel(self, level: str = "INFO"):
"""
Update the log level.
Args:
level (str): The log level to set. Defaults to "INFO".
"""
self.logger.setLevel(level)
def infer_aggregate(
self,
text_list: list,
instruction: str = DEFAULT_INSTRUCTION,
**kwargs,
) -> str:
f"""
Generate a summary of the specified texts.
Args:
text_list (list): The texts to summarize.
instruction (str): The instruction for the summary. Defaults to {self.DEFAULT_INSTRUCTION}.
**kwargs: Additional parameters to update in the generation configuration.
Returns:
The generated summary.
"""
joined_text = "\n".join(text_list)
prompt = f"{instruction}\n\n{joined_text}\n"
if kwargs:
self.update_generation_config(**kwargs)
st = time.perf_counter()
self.logger.info(f"inference on {len(text_list)} texts ...")
result = self.aggregator(
prompt,
generation_config=self.aggregator.model.generation_config,
)[0]["generated_text"]
self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s")
self.logger.info(
f"Input tokens:\t{self.count_tokens(prompt)}. Output tokens:\t{self.count_tokens(result)}"
)
return result
def count_tokens(self, text: str) -> int:
"""count the number of tokens in a text"""
return (
len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
if text
else 0
)
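

if __name__ == "__main__":
    # Minimal usage sketch, assuming the default checkpoint is reachable on the
    # Hugging Face hub; the sample texts below are illustrative placeholders.
    aggregator = BatchAggregator()
    chunk_summaries = [
        "The report opens by describing the project's goals and scope.",
        "Later sections detail the methodology and summarize the key findings.",
    ]
    print(aggregator.infer_aggregate(chunk_summaries))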