---
language:
- ru
tags:
- summarization
- bert
- rubert
license: mit
---
# rubert_ria_headlines
## Description
A *bert2bert* headline-generation model initialized with the `DeepPavlov/rubert-base-cased` pretrained weights and
fine-tuned for 3 epochs on the first 90% of the ["Rossiya Segodnya" news dataset](https://github.com/RossiyaSegodnya/ria_news_dataset).
## Usage example
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_NAME = "dmitry-vorobiev/rubert_ria_headlines"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Replace with the text of a Russian news article ("Copy the article / news text here")
text = "Скопируйте текст статьи / новости"

# Tokenize the article, padding/truncating it to 512 tokens
encoded_batch = tokenizer(
    [text],
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=512,
)

output_ids = model.generate(
    input_ids=encoded_batch["input_ids"],
    # Pass the mask so the encoder ignores the padding added above
    attention_mask=encoded_batch["attention_mask"],
    max_length=32,
    no_repeat_ngram_size=3,
    num_beams=5,
)

headline = tokenizer.decode(
    output_ids[0],
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False,
)
print(headline)
```
## Datasets
- [ria_news](https://github.com/RossiyaSegodnya/ria_news_dataset)
## How it was trained
I used a free TPUv3 on Kaggle. The model was trained for 3 epochs with an effective batch size of 192 and soft restarts: 1500 warmup steps in the first epoch and 500 in each of the next two, with a fresh optimizer state at the start of every epoch.
- [Epoch 1 notebook](https://www.kaggle.com/dvorobiev/try-train-seq2seq-ria-tpu?scriptVersionId=53254694)
- [Epoch 2 notebook](https://www.kaggle.com/dvorobiev/try-train-seq2seq-ria-tpu?scriptVersionId=53269040)
- [Epoch 3 notebook](https://www.kaggle.com/dvorobiev/try-train-seq2seq-ria-tpu?scriptVersionId=53280797)
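The soft-restart schedule described above can be illustrated with a small sketch. It is not code from the training script: it assumes a linear warmup to the peak learning rate of 5e-4 (the `--learning_rate` value below) followed by a linear decay to zero within each epoch, i.e. the shape of Hugging Face's `get_linear_schedule_with_warmup`. The decay shape and the `steps_per_epoch` value are assumptions.

```python
from typing import List

# ASSUMPTIONS (not taken from train_seq2seq.py): linear warmup to the peak LR,
# then linear decay to zero by the end of each epoch, i.e. the shape of
# Hugging Face's get_linear_schedule_with_warmup. steps_per_epoch is hypothetical.
PEAK_LR = 5e-4                   # --learning_rate in the training command
WARMUP_STEPS = [1500, 500, 500]  # warmup steps for epochs 1, 2 and 3


def lr_at(step: int, warmup: int, total: int, peak: float = PEAK_LR) -> float:
    """Learning rate at a 0-based `step` inside one epoch of `total` steps."""
    if step < warmup:
        return peak * step / max(1, warmup)  # linear warmup from 0
    return peak * max(0.0, (total - step) / max(1, total - warmup))  # linear decay


def soft_restart_schedule(steps_per_epoch: int) -> List[float]:
    """Per-step learning rates for all epochs, restarting at each epoch start."""
    lrs: List[float] = []
    for warmup in WARMUP_STEPS:
        # A fresh optimizer (and, assumed here, scheduler) is created at every
        # epoch start, so the step counter and the warmup begin again from zero.
        lrs.extend(lr_at(s, warmup, steps_per_epoch) for s in range(steps_per_epoch))
    return lrs


if __name__ == "__main__":
    steps = 10_000  # hypothetical number of optimizer steps per epoch
    lrs = soft_restart_schedule(steps)
    for epoch, warmup in enumerate(WARMUP_STEPS):
        start = epoch * steps
        print(f"epoch {epoch + 1}: start lr={lrs[start]:.1e}, "
              f"peak lr={lrs[start + warmup]:.1e} at step {warmup}")
```

Restarting the warmup matters because each epoch begins with an empty optimizer state: without a warmup, the first updates would use the full learning rate with uninitialized Adam moment estimates.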
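Training parameters shared by all three runs (the trailing `\` indicates that run-specific arguments follow; see the notebooks above for the full commands):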
```shell
export XLA_USE_BF16=1
export XLA_TENSOR_ALLOCATOR_MAXSIZE=100000000
python nlp_headline_rus/src/train_seq2seq.py \
--do_train \
--tie_encoder_decoder \
--max_source_length 512 \
--max_target_length 32 \
--val_max_target_length 48 \
--tpu_num_cores 8 \
--per_device_train_batch_size 24 \
--gradient_accumulation_steps 1 \
--learning_rate 5e-4 \
--adam_epsilon 1e-6 \
--weight_decay 1e-5 \
```
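The effective batch size of 192 quoted above follows from these flags: the global batch is the number of TPU cores times the per-device batch size times the number of gradient accumulation steps. As a quick check:

```python
# Effective (global) batch size implied by the training flags above.
tpu_num_cores = 8                 # --tpu_num_cores
per_device_train_batch_size = 24  # --per_device_train_batch_size
gradient_accumulation_steps = 1   # --gradient_accumulation_steps

effective_batch_size = (
    tpu_num_cores * per_device_train_batch_size * gradient_accumulation_steps
)
print(effective_batch_size)  # 192
```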
## Validation results
- Using the [last 1% of the ria](https://drive.google.com/drive/folders/1xtCnkbGNNu5jGQ9H9Mg55Cx7RTcyhQw9) dataset
- Using the [last 10% of the ria](https://drive.google.com/drive/folders/1w6rAXhpFUO8I4A7xfHKUjMBPEKBHEO3h) dataset
- Using the [gazeta_ru test](https://drive.google.com/drive/folders/185ALuNVbbT_C1ZHQYn1OlOc9vRVILvHs) split
- Using the [gazeta_ru val](https://drive.google.com/drive/folders/1BLiL3H0n56e8Q9jSuDgaH_3LLpmKxuVG) split
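The ria splits are positional: the model was trained on the first 90% of the dataset and validated on its last 1% and last 10%, so neither validation set overlaps the training portion (and the last 1% is contained in the last 10%). A minimal sketch of such a split, assuming the records are kept in the order in which the dataset is distributed; `positional_splits` and the split names are hypothetical, not part of the training code:

```python
from typing import Dict, List, Sequence, TypeVar

T = TypeVar("T")


def positional_splits(records: Sequence[T]) -> Dict[str, List[T]]:
    """Hypothetical helper: split a dataset by position (no shuffling).

    Integer arithmetic keeps the boundaries exact, so the training slice
    never overlaps either validation slice.
    """
    n = len(records)
    return {
        "train": list(records[: n * 9 // 10]),          # first 90%
        "val_last_10pct": list(records[n - n // 10:]),  # last 10%
        "val_last_1pct": list(records[n - n // 100:]),  # last 1%
    }


if __name__ == "__main__":
    splits = positional_splits(list(range(1000)))  # stand-in for articles
    print({name: len(part) for name, part in splits.items()})
```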