# Spaces:
# Build error
# Build error
from transformers import BartForConditionalGeneration, BartTokenizer | |
from .base_single_doc_model import SingleDocSummModel | |
class BartModel(SingleDocSummModel):
    """Abstractive single-document summarizer built on ``facebook/bart-large-cnn``.

    Wraps the ``BartTokenizer`` / ``BartForConditionalGeneration`` pair from
    ``transformers`` behind the ``SingleDocSummModel`` interface.
    """

    # static variables
    model_name = "BART"
    is_extractive = False
    # BART is a transformer network, i.e. a neural model.
    is_neural = True

    def __init__(self, device="cpu"):
        """Load the pretrained tokenizer and model.

        Args:
            device: Device identifier (e.g. ``"cpu"`` or ``"cuda"``) on which
                both the model weights and the tokenized inputs are placed.
        """
        super().__init__()
        self.device = device
        checkpoint = "facebook/bart-large-cnn"
        self.tokenizer = BartTokenizer.from_pretrained(checkpoint)
        # The weights must live on the same device that ``summarize`` moves
        # the tokenized batch to; otherwise generation fails on any
        # non-CPU device with a device-mismatch error.
        self.model = BartForConditionalGeneration.from_pretrained(
            checkpoint
        ).to(device)

    def summarize(self, corpus, queries=None):
        """Summarize ``corpus`` with BART and return the decoded text.

        Args:
            corpus: Input text(s); checked by ``assert_summ_input_type`` and
                then tokenized together as a single padded, truncated batch.
            queries: Only forwarded to ``assert_summ_input_type``; it is not
                used for generation.

        Returns:
            The generated summaries decoded to text, with special tokens
            removed.
        """
        self.assert_summ_input_type(corpus, queries)
        # Truncate over-long inputs and pad to the longest item so the
        # whole corpus can be generated in one batch.
        batch = self.tokenizer(
            corpus, truncation=True, padding="longest", return_tensors="pt"
        ).to(self.device)
        encoded_summaries = self.model.generate(**batch)
        summaries = self.tokenizer.batch_decode(
            encoded_summaries, skip_special_tokens=True
        )
        return summaries

    @classmethod
    def show_capability(cls) -> None:
        """Print the description produced by ``generate_basic_description``."""
        # TODO zhangir: add the show capability function for BART
        print(cls.generate_basic_description())