---
tags:
- generated_from_trainer
- distilbart
model-index:
- name: distilbart-finetuned-summarization
  results: []
license: apache-2.0
datasets:
- cnn_dailymail
- xsum
- samsum
- ccdv/pubmed-summarization
language:
- en
metrics:
- rouge
---

# distilbart-finetuned-summarization

This model is a further fine-tuned version of [distilbart-cnn-12-6](https://huggingface.co/sshleifer/distilbart-cnn-12-6), trained on a combination of four summarisation datasets:

- [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail)
- [samsum](https://huggingface.co/datasets/samsum)
- [xsum](https://huggingface.co/datasets/xsum)
- [ccdv/pubmed-summarization](https://huggingface.co/datasets/ccdv/pubmed-summarization)

Please check out the official model page and paper:

- [sshleifer/distilbart-cnn-12-6](https://huggingface.co/sshleifer/distilbart-cnn-12-6)
- [Pre-trained Summarization Distillation](https://arxiv.org/abs/2010.13002)

## Training and evaluation data

The combined train, validation and test splits can be reproduced with the following code. Each source dataset's columns are renamed to `document` and `summary`, and the matching splits of all four datasets are then concatenated:

```python
from datasets import DatasetDict, concatenate_datasets, load_dataset

# Load the four source datasets and rename their columns so that every
# dataset exposes the same "document" (input) and "summary" (target) fields.
xsum_dataset = load_dataset("xsum")  # already uses "document" / "summary"
pubmed_dataset = (
    load_dataset("ccdv/pubmed-summarization")
    .rename_column("article", "document")
    .rename_column("abstract", "summary")
)
cnn_dataset = (
    load_dataset("cnn_dailymail", "3.0.0")
    .rename_column("article", "document")
    .rename_column("highlights", "summary")
)
samsum_dataset = load_dataset("samsum").rename_column("dialogue", "document")

# Concatenate the matching split of every dataset (order: xsum, pubmed, cnn, samsum)
all_datasets = [xsum_dataset, pubmed_dataset, cnn_dataset, samsum_dataset]
raw_datasets = DatasetDict(
    {
        split: concatenate_datasets([dataset[split] for dataset in all_datasets])
        for split in ("train", "validation", "test")
    }
)
```
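
The renaming above follows one pattern: each source dataset's input and target columns are mapped onto `document` and `summary`. The sketch below restates that mapping as plain Python so it can be checked at a glance. `COLUMN_MAP` and `harmonise` are hypothetical names used only for illustration (they are not part of the `datasets` library), and the example record is made up rather than taken from any of the datasets.

```python
# Source column names per dataset, as renamed in the reproduction code above.
# COLUMN_MAP and harmonise() are illustrative helpers, not a `datasets` API.
COLUMN_MAP = {
    "xsum": {"document": "document", "summary": "summary"},
    "ccdv/pubmed-summarization": {"article": "document", "abstract": "summary"},
    "cnn_dailymail": {"article": "document", "highlights": "summary"},
    "samsum": {"dialogue": "document", "summary": "summary"},
}


def harmonise(record: dict, source: str) -> dict:
    """Return a copy of `record` with the source's input/target columns renamed.

    Columns not listed in COLUMN_MAP (for example an "id" field) are kept as-is.
    """
    mapping = COLUMN_MAP[source]
    return {mapping.get(key, key): value for key, value in record.items()}


# Made-up record, for illustration only
example = {"id": "0", "article": "Some news text ...", "highlights": "A short summary."}
print(harmonise(example, "cnn_dailymail"))
# {'id': '0', 'document': 'Some news text ...', 'summary': 'A short summary.'}
```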

## Inference example

The snippet below summarises a news article with the `text2text-generation` pipeline:

```python
from transformers import pipeline

# Load the fine-tuned checkpoint from the Hugging Face Hub
pipe = pipeline("text2text-generation", model="lxyuan/distilbart-finetuned-summarization")

text = """SINGAPORE: The Singapore Police Force on Sunday (Jul 16) issued a warning over a
fake SMS impersonating as its "anti-scam centre (ASC)".

"In this scam variant, members of the public would receive a scam SMS from 'ASC',
requesting them to download and install an “anti-scam” app to ensure the security
of their devices," said the police.

"The fake SMS would direct members of the public to a URL link leading to an
Android Package Kit (APK) file, an application created for Android’s operating
system purportedly from 'ASC'."

The fake website has an icon to download the “anti-scam” app and once downloaded,
Android users are asked to allow accessibility services to enable the service.

While the fake app purportedly claims to help identify and prevent scams by
providing comprehensive protection and security, downloading it may enable
scammers to gain remote access to devices.

"Members of the public are advised not to download any suspicious APK files
on their devices as they may contain malware which will allow scammers to
access and take control of the device remotely as well as to steal passwords
stored in the device," said the police.

Members of the public are advised to adopt the following precautionary measures,
including adding anti-virus or anti-malware apps to their devices. They should
also disable “install unknown app” or “unknown sources” in their phone settings.

Users should check the developer information on the app listing as well as the
number of downloads and user reviews to ensure it is a reputable and legitimate
app, the police said.

Any fraudulent transactions should be immediately reported to the banks.
"""

# The pipeline returns one dict per input: [{"generated_text": "..."}]
summary = pipe(text)[0]["generated_text"]
print(summary)
# The Singapore Police Force has issued a warning over a fake SMS
# impersonating as its "anti-scam centre" that asks members of the public
# to download an Android app to ensure the security of their devices, the
# force said on Sunday. The fake SMS would direct people to a URL link
# leading to an Android Package Kit (APK) file, an application created
# for Android’s operating system purportedly from "ASC".
```

## Training procedure

The training notebook is available [here](https://github.com/LxYuan0420/nlp/blob/main/notebooks/distilbart-finetune-summarisation.ipynb).

### Training hyperparameters

The following hyperparameters were used during training:

- evaluation_strategy: epoch
- save_strategy: epoch
- logging_strategy: epoch
- learning_rate: 2e-5
- per_device_train_batch_size: 2
- per_device_eval_batch_size: 2
- gradient_accumulation_steps: 64
- total_train_batch_size: 128
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- weight_decay: 0.01
- save_total_limit: 2
- num_train_epochs: 4
- predict_with_generate: True
- fp16: True
- push_to_hub: True
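
The `total_train_batch_size` of 128 follows from the per-device batch size and gradient accumulation: 2 × 64 = 128, which implies training on a single device (the card does not state the device count, so this is an inference). The sketch below spells out that arithmetic and estimates how many optimizer updates one epoch would take. The helper names are hypothetical, the example dataset size is made up rather than the size of the combined dataset, and the Trainer's own step counting may differ slightly at the edges.

```python
import math


def effective_batch_size(per_device: int, accumulation_steps: int, num_devices: int = 1) -> int:
    """Number of examples contributing to one optimizer update."""
    return per_device * accumulation_steps * num_devices


def updates_per_epoch(num_examples: int, per_device: int, accumulation_steps: int,
                      num_devices: int = 1) -> int:
    """Approximate optimizer updates per epoch.

    Forward/backward batches are rounded up (the last batch may be partial);
    a trailing group of fewer than `accumulation_steps` batches is not counted,
    but at least one update is always reported.
    """
    batches = math.ceil(num_examples / (per_device * num_devices))
    return max(batches // accumulation_steps, 1)


print(effective_batch_size(2, 64))  # 128, matching total_train_batch_size
# Hypothetical dataset size, for illustration only
print(updates_per_epoch(100_000, 2, 64))  # 50,000 batches // 64 = 781
```

With more devices, the same per-device settings would give a larger effective batch, so `gradient_accumulation_steps` would have to be reduced to keep it at 128.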

### Training results

_Training is still in progress_

All metrics are computed on the validation split at the end of each epoch. ROUGE scores are scaled by 100, and Gen Len is the average length of the generated summaries in tokens.

| Epoch | Training Loss | Validation Loss | Rouge1 | Rouge2 | RougeL | RougeLsum | Gen Len |
|-------|---------------|-----------------|--------|--------|--------|-----------|---------|
| 0 | 1.779700 | 1.719054 | 40.003900 | 17.907100 | 27.882500 | 34.888600 | 88.893600 |
| 1 | 1.633800 | 1.710876 | 40.628800 | 18.470200 | 28.428100 | 35.577500 | 88.885000 |
| 2 | 1.566100 | 1.694476 | 40.928500 | 18.695300 | 28.613300 | 35.813300 | 88.993700 |
| 3 | 1.515700 | 1.691141 | 40.860500 | 18.696500 | 28.672700 | 35.734600 | 88.457300 |
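
The ROUGE columns measure overlap between generated and reference summaries. As a rough illustration of what Rouge1 captures, the sketch below computes a unigram-overlap F1 in plain Python. It is a simplified assumption, not the scorer used for the table: it lowercases and keeps runs of letters and digits but applies no stemming, so it will not reproduce the reported numbers exactly. The function names and the example sentences are made up for illustration.

```python
import re
from collections import Counter


def tokens(text: str) -> list:
    # Lowercase and keep runs of letters/digits; no stemming (a simplification).
    return re.findall(r"[a-z0-9]+", text.lower())


def rouge1_f1(candidate: str, reference: str) -> float:
    """Unigram-overlap F1 between a generated and a reference summary."""
    cand, ref = Counter(tokens(candidate)), Counter(tokens(reference))
    overlap = sum((cand & ref).values())  # matches clipped to the smaller count
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1("police warn of fake anti-scam app",
                  "police issue warning over fake anti-scam SMS")
print(round(100 * score, 2))  # 53.33, on the same x100 scale as the table
```

Rouge2 applies the same idea to bigrams, while RougeL and RougeLsum are based on the longest common subsequence, so they are not covered by this sketch.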

### Framework versions

- Transformers 4.30.2
- Pytorch 2.0.1+cu117
- Datasets 2.13.1
- Tokenizers 0.13.3