IlyaGusev committed
Commit c587e90
Parent: d51c43e

Update README.md

Files changed (1): README.md (+12, -19)
README.md CHANGED
@@ -104,6 +104,7 @@ Predicting all summaries:
 import json
 import torch
 from transformers import MBartTokenizer, MBartForConditionalGeneration
+from datasets import load_dataset
 
 
 def gen_batch(inputs, batch_size):
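The body of `gen_batch` sits between these hunks and is untouched by the commit. A minimal batching generator consistent with its call sites would look like this (a sketch, not necessarily the exact code in the README):

```python
def gen_batch(inputs, batch_size):
    # Yield consecutive slices of `inputs`, at most `batch_size` records each.
    for start in range(0, len(inputs), batch_size):
        yield inputs[start:start + batch_size]
```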
@@ -115,26 +116,19 @@ def gen_batch(inputs, batch_size):
 
 def predict(
     model_name,
-    test_file,
-    predictions_file,
-    targets_file,
+    input_records,
+    output_file,
     max_source_tokens_count=600,
-    use_cuda=True,
     batch_size=4
 ):
-    inputs = []
-    targets = []
-    with open(test_file, "r") as r:
-        for line in r:
-            record = json.loads(line)
-            inputs.append(record["text"])
-            targets.append(record["summary"].replace("\n", " "))
-
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
     tokenizer = MBartTokenizer.from_pretrained(model_name)
-    device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
     model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
+
     predictions = []
-    for batch in gen_batch(inputs, batch_size):
+    for batch in gen_batch(input_records, batch_size):
+        texts = [r["text"] for r in batch]
         input_ids = tokenizer(
-            batch,
+            texts,
             return_tensors="pt",
@@ -142,22 +136,21 @@ def predict(
             truncation=True,
             max_length=max_source_tokens_count
         )["input_ids"].to(device)
+
         output_ids = model.generate(
             input_ids=input_ids,
-            repetition_penalty=3.0
+            no_repeat_ngram_size=4
         )
         summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         for s in summaries:
             print(s)
         predictions.extend(summaries)
-    with open(predictions_file, "w") as w:
+    with open(output_file, "w") as w:
         for p in predictions:
             w.write(p.strip().replace("\n", " ") + "\n")
-    with open(targets_file, "w") as w:
-        for t in targets:
-            w.write(t.strip().replace("\n", " ") + "\n")
 
-predict("IlyaGusev/mbart_ru_sum_gazeta", "gazeta_test.jsonl", "predictions.txt", "targets.txt")
+gazeta_test = load_dataset('IlyaGusev/gazeta', script_version="v1.0")["test"]
+predict("IlyaGusev/mbart_ru_sum_gazeta", list(gazeta_test), "mbart_predictions.txt")
 ```
 
 Evaluation: https://github.com/IlyaGusev/summarus/blob/master/evaluate.py
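The substantive decoding change replaces a soft, multiplicative repetition penalty with a hard constraint: with `no_repeat_ngram_size=4`, `generate` never emits the same 4-gram twice within one summary. Both are standard `transformers` generation arguments; side by side:

```python
# Old call: down-weights tokens that were already generated (soft penalty).
output_ids = model.generate(input_ids=input_ids, repetition_penalty=3.0)

# New call: forbids any 4-token sequence from repeating (hard constraint).
output_ids = model.generate(input_ids=input_ids, no_repeat_ngram_size=4)
```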
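For reference, applying these hunks leaves the README snippet in the following state. The reassembly below uses only the context and `+` lines above, with two hedges: the `gen_batch` body is the earlier sketch, and the `padding="max_length"` tokenizer argument falls between the hunks and is not shown in the diff, so it is an assumption:

```python
import json
import torch
from transformers import MBartTokenizer, MBartForConditionalGeneration
from datasets import load_dataset


def gen_batch(inputs, batch_size):
    # Sketch: yield consecutive slices of `inputs`, `batch_size` records each.
    for start in range(0, len(inputs), batch_size):
        yield inputs[start:start + batch_size]


def predict(
    model_name,
    input_records,
    output_file,
    max_source_tokens_count=600,
    batch_size=4
):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = MBartTokenizer.from_pretrained(model_name)
    model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)

    predictions = []
    for batch in gen_batch(input_records, batch_size):
        texts = [r["text"] for r in batch]
        input_ids = tokenizer(
            texts,
            return_tensors="pt",
            padding="max_length",  # assumption: this line is elided from the diff
            truncation=True,
            max_length=max_source_tokens_count
        )["input_ids"].to(device)

        output_ids = model.generate(
            input_ids=input_ids,
            no_repeat_ngram_size=4
        )
        summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        for s in summaries:
            print(s)
        predictions.extend(summaries)
    with open(output_file, "w") as w:
        for p in predictions:
            w.write(p.strip().replace("\n", " ") + "\n")


gazeta_test = load_dataset('IlyaGusev/gazeta', script_version="v1.0")["test"]
predict("IlyaGusev/mbart_ru_sum_gazeta", list(gazeta_test), "mbart_predictions.txt")
```

Each `gazeta` record is a dict with at least `"text"` and `"summary"` fields (the removed code read both), so `list(gazeta_test)` supplies the `input_records` that `predict` expects; only `"text"` is used for prediction, and reference summaries now come from the dataset itself rather than a separate `targets.txt`.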