Update README.md
Browse files
README.md
CHANGED
@@ -25,26 +25,25 @@ For more details, please see [Dataset for Automatic Summarization of Russian New
|
|
25 |
```python
|
26 |
from transformers import MBartTokenizer, MBartForConditionalGeneration
|
27 |
|
28 |
-
article_text = "..."
|
29 |
model_name = "IlyaGusev/mbart_ru_sum_gazeta"
|
30 |
tokenizer = MBartTokenizer.from_pretrained(model_name)
|
31 |
model = MBartForConditionalGeneration.from_pretrained(model_name)
|
32 |
|
33 |
-
|
|
|
34 |
[article_text],
|
35 |
-
|
36 |
-
return_tensors="pt",
|
37 |
padding="max_length",
|
38 |
truncation=True,
|
39 |
-
|
40 |
)["input_ids"]
|
41 |
|
42 |
output_ids = model.generate(
|
43 |
input_ids=input_ids,
|
44 |
-
|
45 |
)[0]
|
46 |
|
47 |
-
summary = tokenizer.decode(output_ids, skip_special_tokens=True)
|
48 |
print(summary)
|
49 |
```
|
50 |
|
@@ -55,12 +54,12 @@ print(summary)
|
|
55 |
|
56 |
## Training data
|
57 |
|
58 |
-
- Dataset: https://
|
59 |
|
60 |
## Training procedure
|
61 |
|
62 |
-
- Fairseq training script: https://github.com/IlyaGusev/summarus/blob/master/external/bart_scripts/train.sh
|
63 |
-
- Porting: https://colab.research.google.com/drive/13jXOlCpArV-lm4jZQ0VgOpj6nFBYrLAr
|
64 |
|
65 |
## Eval results
|
66 |
|
@@ -98,7 +97,6 @@ def predict(
|
|
98 |
predictions_file,
|
99 |
targets_file,
|
100 |
max_source_tokens_count=600,
|
101 |
-
max_target_tokens_count=160,
|
102 |
use_cuda=True,
|
103 |
batch_size=4
|
104 |
):
|
@@ -115,9 +113,8 @@ def predict(
|
|
115 |
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
|
116 |
predictions = []
|
117 |
for batch in gen_batch(inputs, batch_size):
|
118 |
-
input_ids = tokenizer(
|
119 |
batch,
|
120 |
-
src_lang="en_XX",
|
121 |
return_tensors="pt",
|
122 |
padding="max_length",
|
123 |
truncation=True,
|
@@ -125,12 +122,9 @@ def predict(
|
|
125 |
)["input_ids"].to(device)
|
126 |
output_ids = model.generate(
|
127 |
input_ids=input_ids,
|
128 |
-
|
129 |
-
no_repeat_ngram_size=3,
|
130 |
-
num_beams=5,
|
131 |
-
top_k=0
|
132 |
)
|
133 |
-
summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
134 |
for s in summaries:
|
135 |
print(s)
|
136 |
predictions.extend(summaries)
|
|
|
25 |
```python
|
26 |
from transformers import MBartTokenizer, MBartForConditionalGeneration
|
27 |
|
|
|
28 |
model_name = "IlyaGusev/mbart_ru_sum_gazeta"
|
29 |
tokenizer = MBartTokenizer.from_pretrained(model_name)
|
30 |
model = MBartForConditionalGeneration.from_pretrained(model_name)
|
31 |
|
32 |
+
article_text = "..."
|
33 |
+
input_ids = tokenizer(
|
34 |
[article_text],
|
35 |
+
max_length=600,
|
|
|
36 |
padding="max_length",
|
37 |
truncation=True,
|
38 |
+
return_tensors="pt",
|
39 |
)["input_ids"]
|
40 |
|
41 |
output_ids = model.generate(
|
42 |
input_ids=input_ids,
|
43 |
+
repetition_penalty=3.0
|
44 |
)[0]
|
45 |
|
46 |
+
summary = tokenizer.decode(output_ids, skip_special_tokens=True)
|
47 |
print(summary)
|
48 |
```
|
49 |
|
|
|
54 |
|
55 |
## Training data
|
56 |
|
57 |
+
- Dataset: [Gazeta](https://huggingface.co/datasets/IlyaGusev/gazeta)
|
58 |
|
59 |
## Training procedure
|
60 |
|
61 |
+
- Fairseq training script: [train.sh](https://github.com/IlyaGusev/summarus/blob/master/external/bart_scripts/train.sh)
|
62 |
+
- Porting: [Colab link](https://colab.research.google.com/drive/13jXOlCpArV-lm4jZQ0VgOpj6nFBYrLAr)
|
63 |
|
64 |
## Eval results
|
65 |
|
|
|
97 |
predictions_file,
|
98 |
targets_file,
|
99 |
max_source_tokens_count=600,
|
|
|
100 |
use_cuda=True,
|
101 |
batch_size=4
|
102 |
):
|
|
|
113 |
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
|
114 |
predictions = []
|
115 |
for batch in gen_batch(inputs, batch_size):
|
116 |
+
input_ids = tokenizer(
|
117 |
batch,
|
|
|
118 |
return_tensors="pt",
|
119 |
padding="max_length",
|
120 |
truncation=True,
|
|
|
122 |
)["input_ids"].to(device)
|
123 |
output_ids = model.generate(
|
124 |
input_ids=input_ids,
|
125 |
+
repetition_penalty=3.0
|
|
|
|
|
|
|
126 |
)
|
127 |
+
summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
128 |
for s in summaries:
|
129 |
print(s)
|
130 |
predictions.extend(summaries)
|