dmitry-vorobiev committed 5b67024 (parent: f70d5ba): "upd readme"

README.md CHANGED
@@ -10,7 +10,7 @@ license: MIT
 
 ## Description
 *bert2bert* model, initialized with the `DeepPavlov/rubert-base-cased` pretrained weights and
-fine-tuned on the first 90% of ["Rossiya Segodnya" news dataset](https://github.com/RossiyaSegodnya/ria_news_dataset) for
+fine-tuned on the first 90% of ["Rossiya Segodnya" news dataset](https://github.com/RossiyaSegodnya/ria_news_dataset) for 3 epochs.
 
 ## Usage example
 
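The *bert2bert* initialization described above can be sketched with the `transformers` library's `EncoderDecoderModel`. A minimal, download-free sketch: the card initializes both halves from `DeepPavlov/rubert-base-cased`, which the comment below reproduces, while toy `BertConfig`s stand in for the real checkpoint so the wiring can be shown without fetching weights.

```python
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

# The card's actual initialization (downloads the checkpoint):
#   EncoderDecoderModel.from_encoder_decoder_pretrained(
#       "DeepPavlov/rubert-base-cased", "DeepPavlov/rubert-base-cased",
#       tie_encoder_decoder=True)  # mirrors the --tie_encoder_decoder flag
# The same encoder/decoder wiring, built from toy configs instead:
encoder_cfg = BertConfig(vocab_size=128, hidden_size=32, num_hidden_layers=2,
                         num_attention_heads=2, intermediate_size=64)
decoder_cfg = BertConfig(vocab_size=128, hidden_size=32, num_hidden_layers=2,
                         num_attention_heads=2, intermediate_size=64,
                         is_decoder=True,          # decoder attends causally
                         add_cross_attention=True) # and over encoder outputs
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_cfg, decoder_cfg)
model = EncoderDecoderModel(config=config)
print(model.config.is_encoder_decoder)  # True
```

The sizes above are arbitrary illustration values, not the rubert-base-cased dimensions.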
@@ -50,27 +50,26 @@ print(headline)
 
 ## How it was trained?
 
-
+I used a free TPUv3 on Kaggle. The model was trained for 3 epochs with an effective batch size of 256 and soft restarts.
 
-1. [
-2. [
-3. [
-4. [1.6 ep](https://www.kaggle.com/dvorobiev/train-seq2seq?scriptVersionId=52876230)
+1. [1 ep](https://www.kaggle.com/dvorobiev/try-train-seq2seq-ria-tpu?scriptVersionId=53094837)
+2. [2 ep](https://www.kaggle.com/dvorobiev/try-train-seq2seq-ria-tpu?scriptVersionId=53109219)
+3. [3 ep](https://www.kaggle.com/dvorobiev/try-train-seq2seq-ria-tpu?scriptVersionId=53171375)
 
 Common train params:
 
 ```shell
 python nlp_headline_rus/src/train_seq2seq.py \
     --do_train \
-    --fp16 \
     --tie_encoder_decoder \
     --max_source_length 512 \
     --max_target_length 32 \
     --val_max_target_length 48 \
-    --
-    --
-    --
-    --
+    --tpu_num_cores 8 \
+    --per_device_train_batch_size 32 \
+    --gradient_accumulation_steps 1 \
+    --warmup_steps 500 \
+    --learning_rate 1e-3 \
     --adam_epsilon 1e-6 \
     --weight_decay 1e-5 \
 ```
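The new TPU flags pin down the "effective batch size 256" quoted in the prose. Assuming the usual `Trainer` convention (number of devices times per-device batch times accumulation steps), the arithmetic checks out:

```python
# Effective batch size implied by the flags above (assumed convention:
# num_devices * per_device_train_batch_size * gradient_accumulation_steps).
tpu_num_cores = 8
per_device_train_batch_size = 32
gradient_accumulation_steps = 1

effective_batch_size = (tpu_num_cores
                        * per_device_train_batch_size
                        * gradient_accumulation_steps)
print(effective_batch_size)  # 256
```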
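With `--warmup_steps 500` and `--learning_rate 1e-3`, the `Trainer` default schedule ramps the learning rate linearly to the peak over the first 500 steps, then decays it linearly to zero. A plain-Python sketch of that shape; the total step count here is a hypothetical stand-in, not a value from the card:

```python
def linear_warmup_decay_lr(step, peak_lr=1e-3, warmup_steps=500, total_steps=10_000):
    """LR at a given optimizer step: linear warmup, then linear decay to 0."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))

print(linear_warmup_decay_lr(250))     # halfway through warmup
print(linear_warmup_decay_lr(500))     # peak rate, warmup finished
print(linear_warmup_decay_lr(10_000))  # fully decayed
```

The "soft restarts" mentioned above (resuming each Kaggle session from the previous checkpoint) would restart this schedule from step 0 on each run.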