Update README.md
Browse files
README.md
CHANGED
@@ -30,7 +30,7 @@ model_name = "IlyaGusev/rut5-base-sum-gazeta"
|
|
30 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
31 |
model = T5ForConditionalGeneration.from_pretrained(model_name)
|
32 |
|
33 |
-
article_text = "
|
34 |
|
35 |
input_ids = tokenizer(
|
36 |
[article_text],
|
@@ -68,7 +68,7 @@ Predicting all summaries:
|
|
68 |
```python
|
69 |
import json
|
70 |
import torch
|
71 |
-
from transformers import
|
72 |
from datasets import load_dataset
|
73 |
|
74 |
|
@@ -89,8 +89,8 @@ def predict(
|
|
89 |
):
|
90 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
91 |
|
92 |
-
tokenizer =
|
93 |
-
model =
|
94 |
|
95 |
predictions = []
|
96 |
for batch in gen_batch(input_records, batch_size):
|
@@ -108,7 +108,6 @@ def predict(
|
|
108 |
input_ids=input_ids,
|
109 |
max_length=max_target_tokens_count,
|
110 |
no_repeat_ngram_size=3,
|
111 |
-
num_beams=5,
|
112 |
early_stopping=True
|
113 |
)
|
114 |
summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
@@ -120,7 +119,7 @@ def predict(
|
|
120 |
w.write(p.strip().replace("\n", " ") + "\n")
|
121 |
|
122 |
gazeta_test = load_dataset('IlyaGusev/gazeta', script_version="v1.0")["test"]
|
123 |
-
predict("IlyaGusev/
|
124 |
```
|
125 |
|
126 |
Evaluation: https://github.com/IlyaGusev/summarus/blob/master/evaluate.py
|
|
|
30 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
31 |
model = T5ForConditionalGeneration.from_pretrained(model_name)
|
32 |
|
33 |
+
article_text = "..."
|
34 |
|
35 |
input_ids = tokenizer(
|
36 |
[article_text],
|
|
|
68 |
```python
|
69 |
import json
|
70 |
import torch
|
71 |
+
from transformers import AutoTokenizer, T5ForConditionalGeneration
|
72 |
from datasets import load_dataset
|
73 |
|
74 |
|
|
|
89 |
):
|
90 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
91 |
|
92 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
93 |
+
model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
|
94 |
|
95 |
predictions = []
|
96 |
for batch in gen_batch(input_records, batch_size):
|
|
|
108 |
input_ids=input_ids,
|
109 |
max_length=max_target_tokens_count,
|
110 |
no_repeat_ngram_size=3,
|
|
|
111 |
early_stopping=True
|
112 |
)
|
113 |
summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
|
|
|
119 |
w.write(p.strip().replace("\n", " ") + "\n")
|
120 |
|
121 |
gazeta_test = load_dataset('IlyaGusev/gazeta', script_version="v1.0")["test"]
|
122 |
+
predict("IlyaGusev/rut5-base-sum-gazeta", list(gazeta_test), "t5_predictions.txt")
|
123 |
```
|
124 |
|
125 |
Evaluation: https://github.com/IlyaGusev/summarus/blob/master/evaluate.py
|