# ChatVID/model/summary/TextSummarizer.py
import torch
from transformers import T5ForConditionalGeneration, T5TokenizerFast
class TextSummarizer:
    """Text summarizer backed by a fine-tuned Flan-T5 checkpoint.

    Loads a T5-family seq2seq model once at construction and exposes a
    simple callable interface: ``summarizer(text) -> list[str]``.
    """

    # Default checkpoint directory, kept for backward compatibility with
    # the original hard-coded path.
    DEFAULT_MODEL_DIR = (
        "./pretrained_models/"
        "flan-t5-large-finetuned-openai-summarize_from_feedback")

    def __init__(self, device: str = 'cuda', model_dir: str = None):
        """
        Args:
            device (str, optional): torch device to run on. Defaults to 'cuda'.
            model_dir (str, optional): path to the pretrained checkpoint
                directory. Defaults to ``DEFAULT_MODEL_DIR``.
        """
        self._load_model(
            model_type="t5",
            model_dir=model_dir if model_dir is not None
            else self.DEFAULT_MODEL_DIR,
            device=device)

    def _load_model(self,
                    model_type: str = "t5",
                    model_dir: str = "outputs",
                    device: str = 'cuda'):
        """
        loads a checkpoint for inferencing/prediction
        Args:
            model_type (str, optional): "t5" or "mt5". Defaults to "t5".
            model_dir (str, optional): path to model directory.
                Defaults to "outputs".
            device (str, optional): device to run. Defaults to "cuda".
        Raises:
            NotImplementedError: for any ``model_type`` other than "t5".
        """
        if model_type == "t5":
            self.model = T5ForConditionalGeneration.from_pretrained(
                f"{model_dir}")
            self.tokenizer = T5TokenizerFast.from_pretrained(f"{model_dir}")
        else:
            raise NotImplementedError(
                f"model_type {model_type} not implemented")
        self.device = torch.device(device)
        self.model = self.model.to(self.device)
        # Inference-only: disable dropout etc.
        self.model.eval()

    def _predict(
        self,
        source_text: str,
        max_length: int = 512,
        num_return_sequences: int = 1,
        num_beams: int = 2,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True,
        repetition_penalty: float = 2.5,
        length_penalty: float = 1.0,
        early_stopping: bool = True,
        skip_special_tokens: bool = True,
        clean_up_tokenization_spaces: bool = True,
    ):
        """
        generates prediction for T5/MT5 model
        Args:
            source_text (str): any text for generating predictions
            max_length (int, optional): max token length of prediction.
                Defaults to 512.
            num_return_sequences (int, optional): number of predictions to be
                returned. Defaults to 1.
            num_beams (int, optional): number of beams. Defaults to 2.
            top_k (int, optional): Defaults to 50.
            top_p (float, optional): Defaults to 0.95.
            do_sample (bool, optional): Defaults to True.
            repetition_penalty (float, optional): Defaults to 2.5.
            length_penalty (float, optional): Defaults to 1.0.
            early_stopping (bool, optional): Defaults to True.
            skip_special_tokens (bool, optional): Defaults to True.
            clean_up_tokenization_spaces (bool, optional): Defaults to True.
        Returns:
            list[str]: returns predictions
        """
        # Truncate over-length inputs to the requested max so generate()
        # never sees sequences past the model's limit.
        input_ids = self.tokenizer.encode(
            source_text,
            return_tensors="pt",
            add_special_tokens=True,
            truncation=True,
            max_length=max_length)
        input_ids = input_ids.to(self.device)
        # no_grad: inference only — skip building the autograd graph.
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                num_beams=num_beams,
                max_length=max_length,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                early_stopping=early_stopping,
                # BUG FIX: do_sample was accepted but never forwarded, so
                # top_k/top_p had no effect (generate defaults to greedy/beam).
                do_sample=do_sample,
                top_p=top_p,
                top_k=top_k,
                num_return_sequences=num_return_sequences,
            )
        preds = [
            self.tokenizer.decode(
                g,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            ) for g in generated_ids
        ]
        return preds

    def __call__(self, source_text):
        """Summarize ``source_text``; returns a list[str] of predictions."""
        generated_text = self._predict(source_text=source_text)
        return generated_text