Spaces:
Build error
Build error
from transformers import PegasusForConditionalGeneration, PegasusTokenizer | |
from .base_single_doc_model import SingleDocSummModel | |
class PegasusModel(SingleDocSummModel):
    """Abstractive single-document summarizer using the ``google/pegasus-xsum`` checkpoint."""

    # static variables
    model_name = "Pegasus"  # display name of this summarizer
    is_extractive = False  # generates new text rather than selecting source sentences
    is_neural = True

    def __init__(self, device="cpu"):
        """Load the pretrained Pegasus tokenizer and model.

        Args:
            device: Device the model is placed on and computes with, e.g.
                ``"cpu"`` or ``"cuda"``. ``"gpu"`` is accepted as an alias for
                ``"cuda"`` because ``show_capability`` advertises that spelling.
        """
        super().__init__()
        # torch has no "gpu" device type; map the advertised alias to "cuda".
        self.device = "cuda" if device == "gpu" else device
        checkpoint = "google/pegasus-xsum"
        print("init load pretrained tokenizer")
        self.tokenizer = PegasusTokenizer.from_pretrained(checkpoint)
        print("init load pretrained model with tokenizer on " + self.device)
        # Place the weights on the requested device so they match the input
        # tensors that summarize() moves to the same device.
        self.model = PegasusForConditionalGeneration.from_pretrained(checkpoint).to(
            self.device
        )

    def summarize(self, corpus, queries=None):
        """Generate one abstractive summary per document in ``corpus``.

        Args:
            corpus: Documents to summarize, validated by the base class's
                ``assert_summ_input_type`` (presumably a list of strings —
                confirm against the base class).
            queries: Not used for generation; only forwarded to the input check.

        Returns:
            A list of decoded summary strings, one per generated sequence, in
            the same order as the documents in ``corpus``.
        """
        self.assert_summ_input_type(corpus, queries)

        print("batching")
        # Pad to the longest document so documents of different lengths can be
        # stacked into a single tensor batch; truncate overly long inputs.
        batch = self.tokenizer(
            corpus, truncation=True, padding="longest", return_tensors="pt"
        ).to(self.device)

        print("encoding batches")
        # Unpack the whole encoding so the attention mask is passed too and
        # padding positions are ignored. max_time is a wall-clock budget in seconds.
        encoded_summaries = self.model.generate(**batch, max_time=1024)

        print("decoding batches")
        # Decode every generated sequence (not just the first) and strip
        # special tokens such as padding and end-of-sequence markers.
        summaries = self.tokenizer.batch_decode(
            encoded_summaries, skip_special_tokens=True
        )
        return summaries

    @classmethod
    def show_capability(cls):
        """Print the base-class description plus Pegasus-specific details."""
        basic_description = cls.generate_basic_description()
        more_details = (
            "Introduced in 2019, a large neural abstractive summarization model trained on web crawl and "
            "news data.\n "
            "Strengths: \n - High accuracy \n - Performs well on almost all kinds of non-literary written "
            "text \n "
            "Weaknesses: \n - High memory usage \n "
            "Initialization arguments: \n "
            "- `device = 'cpu'` specifies the device the model is stored on and uses for computation. "
            "Use `device='gpu'` to run on an Nvidia GPU."
        )
        print(f"{basic_description} \n {'#'*20} \n {more_details}")