from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM


class T5Summarizer:
    """Generates a short, title-style summary for a piece of text using a T5 model."""

    def __init__(self, model_name: str = "fabiochiu/t5-small-medium-title-generation"):
        # Load the tokenizer and the TensorFlow seq2seq model from the Hugging Face Hub.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

    def summarize(self, text: str) -> str:
        # T5 expects a task prefix, so prepend "summarize: " to the input text.
        prompt = ["summarize: " + text]
        # Truncate the input to the model's maximum supported sequence length.
        max_input_length = self.tokenizer.model_max_length
        inputs = self.tokenizer(prompt, max_length=max_input_length, truncation=True, return_tensors="tf")
        # Generate with 12 beams and sampling enabled; keep the output between 2 and 12 tokens.
        output = self.model.generate(**inputs, num_beams=12, do_sample=True, min_length=2, max_length=12)
        # Decode the first generated sequence back into plain text.
        summary = self.tokenizer.batch_decode(output, skip_special_tokens=True)[0]
        return summary
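

# Minimal usage sketch (assumptions: the model weights can be downloaded from the
# Hugging Face Hub, TensorFlow is installed, and the example text below is
# illustrative only, not part of the original file).
if __name__ == "__main__":
    summarizer = T5Summarizer()
    article = (
        "Researchers released a new open-source library that simplifies training "
        "and deploying transformer models for text summarization tasks."
    )
    print(summarizer.summarize(article))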