---
language:
- ru
- ru-RU
tags:
- summarization
- mbart
datasets:
- IlyaGusev/gazeta
license: apache-2.0
---
# MBARTRuSumGazeta
## Model description
This is a ported version of the [fairseq model](https://www.dropbox.com/s/fijtntnifbt9h0k/gazeta_mbart_v2_fairseq.tar.gz).
For more details, please see [Dataset for Automatic Summarization of Russian News](https://arxiv.org/abs/2006.11063).
## Intended uses & limitations
#### How to use
```python
from transformers import MBartTokenizer, MBartForConditionalGeneration

article_text = "..."

model_name = "IlyaGusev/mbart_ru_sum_gazeta"
tokenizer = MBartTokenizer.from_pretrained(model_name)
model = MBartForConditionalGeneration.from_pretrained(model_name)

input_ids = tokenizer.prepare_seq2seq_batch(
    [article_text],
    src_lang="en_XX",  # fairseq training artifact
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=600
)["input_ids"]

output_ids = model.generate(
    input_ids=input_ids,
    max_length=162,
    no_repeat_ngram_size=3,
    num_beams=5,
    top_k=0
)[0]

summary = tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(summary)
```
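Note: `prepare_seq2seq_batch` has been deprecated in recent `transformers` releases and removed in later ones. On versions where it is unavailable, calling the tokenizer directly should prepare the same inputs for this model; this is a sketch under that assumption, not part of the original port:

```python
# Equivalent input preparation on newer transformers versions, where
# prepare_seq2seq_batch no longer exists. Generation parameters are
# unchanged from the snippet above.
input_ids = tokenizer(
    [article_text],
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=600
)["input_ids"]
```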
#### Limitations and bias
- The model should work well on Gazeta.ru articles; on texts from other outlets it may suffer from domain shift
## Training data
- Dataset: https://github.com/IlyaGusev/gazeta
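
The test file consumed by the prediction script further below is JSON Lines, one record per line with `text` and `summary` fields. If you prefer the Hugging Face Hub mirror of the dataset, something like the following should work (the hub id matches the metadata above; the `test` split name is an assumption):

```python
# Hedged sketch: fetch the Gazeta dataset from the Hugging Face Hub.
# The "test" split and the "text"/"summary" field names are assumptions
# based on how the prediction script below reads its input file.
from datasets import load_dataset

gazeta_test = load_dataset("IlyaGusev/gazeta", split="test")
print(gazeta_test[0]["text"][:200])
print(gazeta_test[0]["summary"])
```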
## Training procedure
- Fairseq training script: https://github.com/IlyaGusev/summarus/blob/master/external/bart_scripts/train.sh
- Porting: https://colab.research.google.com/drive/13jXOlCpArV-lm4jZQ0VgOpj6nFBYrLAr
## Eval results
| Model        | ROUGE-1-f | ROUGE-2-f | ROUGE-L-f | METEOR | BLEU |
|:-------------|:----------|:----------|:----------|:-------|:-----|
| gazeta_mbart | 32.3      | 14.3      | 27.9      | 25.5   | 12.4 |
Predicting all summaries:
```python
import json

import torch
from transformers import MBartTokenizer, MBartForConditionalGeneration


def gen_batch(inputs, batch_size):
    # Yield consecutive slices of `inputs` of size `batch_size`.
    batch_start = 0
    while batch_start < len(inputs):
        yield inputs[batch_start: batch_start + batch_size]
        batch_start += batch_size


def predict(
    model_name,
    test_file,
    predictions_file,
    targets_file,
    max_source_tokens_count=600,
    max_target_tokens_count=160,
    use_cuda=True,
    batch_size=4
):
    # Read the test set: one JSON record per line with "text" and "summary" fields.
    inputs = []
    targets = []
    with open(test_file, "r") as r:
        for line in r:
            record = json.loads(line)
            inputs.append(record["text"])
            targets.append(record["summary"].replace("\n", " "))

    tokenizer = MBartTokenizer.from_pretrained(model_name)
    device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
    model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)

    predictions = []
    for batch in gen_batch(inputs, batch_size):
        input_ids = tokenizer.prepare_seq2seq_batch(
            batch,
            src_lang="en_XX",  # fairseq training artifact
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_source_tokens_count
        )["input_ids"].to(device)
        output_ids = model.generate(
            input_ids=input_ids,
            max_length=max_target_tokens_count + 2,
            no_repeat_ngram_size=3,
            num_beams=5,
            top_k=0
        )
        summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        for s in summaries:
            print(s)
        predictions.extend(summaries)

    # Write predictions and references line-aligned for the evaluation script.
    with open(predictions_file, "w") as w:
        for p in predictions:
            w.write(p.strip().replace("\n", " ") + "\n")
    with open(targets_file, "w") as w:
        for t in targets:
            w.write(t.strip().replace("\n", " ") + "\n")


predict("IlyaGusev/mbart_ru_sum_gazeta", "gazeta_test.jsonl", "predictions.txt", "targets.txt")
```
Evaluation script: https://github.com/IlyaGusev/summarus/blob/master/evaluate.py

Flags: `--language ru --tokenize-after --lower`
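
For a quick sanity check without the summarus tooling, per-pair ROUGE F-measures can be averaged with the `rouge-score` package. This is only a sketch: its tokenization differs from `evaluate.py` (which re-tokenizes and lowercases), so the numbers will not match the table exactly.

```python
# Rough ROUGE sanity check over the files written by predict() above.
# Requires: pip install rouge-score
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=False)

with open("predictions.txt") as p, open("targets.txt") as t:
    pairs = list(zip(p, t))

# scorer.score(reference, prediction) returns a dict of Score namedtuples.
scores = [scorer.score(target.strip(), prediction.strip()) for prediction, target in pairs]
for key in ["rouge1", "rouge2", "rougeL"]:
    avg_f = sum(s[key].fmeasure for s in scores) / len(scores)
    print(f"{key}: {avg_f:.3f}")
```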
### BibTeX entry and citation info
```bibtex
@InProceedings{10.1007/978-3-030-59082-6_9,
author="Gusev, Ilya",
editor="Filchenkov, Andrey and Kauttonen, Janne and Pivovarova, Lidia",
title="Dataset for Automatic Summarization of Russian News",
booktitle="Artificial Intelligence and Natural Language",
year="2020",
publisher="Springer International Publishing",
address="Cham",
pages="122--134",
isbn="978-3-030-59082-6"
}
```