---
language:
- ru
- ru-RU
tags:
- summarization
- t5
datasets:
- IlyaGusev/gazeta
license: apache-2.0
---

# RuT5SumGazeta

## Model description

This is a model for abstractive summarization of Russian text, based on [rut5-base](https://huggingface.co/cointegrated/rut5-base).

## Intended uses & limitations

#### How to use

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

article_text = "..."

model_name = "IlyaGusev/rut5-base-sum-gazeta"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# Tokenize the article; inputs longer than 400 tokens are truncated.
input_ids = tokenizer(
    [article_text],
    add_special_tokens=True,
    padding="max_length",
    truncation=True,
    max_length=400,
    return_tensors="pt"
)["input_ids"]

# Generate a summary of at most 200 tokens using beam search.
output_ids = model.generate(
    input_ids=input_ids,
    max_length=200,
    no_repeat_ngram_size=3,
    num_beams=5,
    early_stopping=True
)[0]

summary = tokenizer.decode(output_ids, skip_special_tokens=True)
print(summary)
```

## Training data

- Dataset: https://github.com/IlyaGusev/gazeta

## Training procedure

- Training script: [TBA]

## Eval results

| Model                | R-1-f | R-2-f | R-L-f | chrF | BLEU |
|:---------------------|:------|:------|:------|:-----|:-----|
| rut5-base-sum-gazeta | 32.3  | 14.5  | 27.9  | 39.6 | 11.5 |

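To make the R-1-f and R-2-f columns concrete, here is a minimal sketch of an n-gram overlap F-measure, the quantity ROUGE-N F reports. It is not the code behind the numbers above: the function name `rouge_n_f` is hypothetical, and lowercasing plus whitespace splitting is an assumed stand-in for a real tokenizer. It returns a value in the range 0 to 1, while the table appears to use a 0 to 100 scale.

```python
from collections import Counter


def rouge_n_f(reference: str, hypothesis: str, n: int = 1) -> float:
    """F-measure of n-gram overlap between a reference and a hypothesis."""

    def ngrams(text: str) -> Counter:
        # Assumption: lowercase and split on whitespace instead of real tokenization.
        tokens = text.lower().split()
        return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

    ref_counts, hyp_counts = ngrams(reference), ngrams(hypothesis)
    # Clipped overlap: each n-gram counts at most as often as it occurs in both texts.
    overlap = sum((ref_counts & hyp_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(hyp_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


# Three of the four unigrams match, so precision and recall are both 0.75.
print(rouge_n_f("Кошка сидит на окне", "кошка спит на окне", n=1))
```

R-L-f relies on the longest common subsequence rather than fixed-size n-grams, and chrF and BLEU are separate metrics; none of these is covered by this sketch.
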
Predicting all summaries:

```python
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from datasets import load_dataset


def gen_batch(inputs, batch_size):
    # Yield consecutive slices of at most batch_size records.
    batch_start = 0
    while batch_start < len(inputs):
        yield inputs[batch_start: batch_start + batch_size]
        batch_start += batch_size


def predict(
    model_name,
    input_records,
    output_file,
    max_source_tokens_count=400,
    max_target_tokens_count=200,
    batch_size=16
):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Use the T5 classes imported above; this checkpoint is a T5 model.
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)

    predictions = []
    for batch in gen_batch(input_records, batch_size):
        texts = [r["text"] for r in batch]
        input_ids = tokenizer(
            texts,
            add_special_tokens=True,
            max_length=max_source_tokens_count,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )["input_ids"].to(device)

        output_ids = model.generate(
            input_ids=input_ids,
            max_length=max_target_tokens_count,
            no_repeat_ngram_size=3,
            num_beams=5,
            early_stopping=True
        )
        summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        for s in summaries:
            print(s)
        predictions.extend(summaries)

    # Write one summary per line.
    with open(output_file, "w", encoding="utf-8") as w:
        for p in predictions:
            w.write(p.strip().replace("\n", " ") + "\n")


# `revision` replaces the deprecated `script_version` argument.
gazeta_test = load_dataset('IlyaGusev/gazeta', revision="v1.0")["test"]
# Convert to a list of records so that slicing in gen_batch yields dicts with a "text" key.
predict("IlyaGusev/rut5-base-sum-gazeta", list(gazeta_test), "t5_predictions.txt")
```

Evaluation: https://github.com/IlyaGusev/summarus/blob/master/evaluate.py

Flags: --language ru --tokenize-after --lower