IlyaGusev committed
Commit 96cf49a (parent: 3798869)

Update README.md

Files changed (1):
  1. README.md (+12 -18)

README.md CHANGED
@@ -25,26 +25,25 @@ For more details, please see [Dataset for Automatic Summarization of Russian New
 ```python
 from transformers import MBartTokenizer, MBartForConditionalGeneration
 
-article_text = "..."
 model_name = "IlyaGusev/mbart_ru_sum_gazeta"
 tokenizer = MBartTokenizer.from_pretrained(model_name)
 model = MBartForConditionalGeneration.from_pretrained(model_name)
 
-input_ids = tokenizer.prepare_seq2seq_batch(
+article_text = "..."
+input_ids = tokenizer(
     [article_text],
-    src_lang="en_XX", # fairseq training artifact
-    return_tensors="pt",
+    max_length=600,
     padding="max_length",
     truncation=True,
-    max_length=600
+    return_tensors="pt",
 )["input_ids"]
 
 output_ids = model.generate(
     input_ids=input_ids,
-    no_repeat_ngram_size=3
+    repetition_penalty=3.0
 )[0]
 
-summary = tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+summary = tokenizer.decode(output_ids, skip_special_tokens=True)
 print(summary)
 ```
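A note on this hunk: `prepare_seq2seq_batch` is deprecated in transformers v4 (calling the tokenizer directly is the supported path), which also makes the `src_lang="en_XX"` workaround unnecessary. A minimal migration sketch, assuming transformers >= 4.x; the text value is a placeholder:

```python
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("IlyaGusev/mbart_ru_sum_gazeta")

# Deprecated: tokenizer.prepare_seq2seq_batch([text], src_lang="en_XX", ...)
# Supported: call the tokenizer directly with the same padding/truncation options.
input_ids = tokenizer(
    ["..."],               # list of article texts (placeholder)
    max_length=600,        # source length cap used throughout the README
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)["input_ids"]             # tensor of shape (batch_size, 600)
```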
 
@@ -55,12 +54,12 @@ print(summary)
 
 ## Training data
 
-- Dataset: https://github.com/IlyaGusev/gazeta
+- Dataset: [Gazeta](https://huggingface.co/datasets/IlyaGusev/gazeta)
 
 ## Training procedure
 
-- Fairseq training script: https://github.com/IlyaGusev/summarus/blob/master/external/bart_scripts/train.sh
-- Porting: https://colab.research.google.com/drive/13jXOlCpArV-lm4jZQ0VgOpj6nFBYrLAr
+- Fairseq training script: [train.sh](https://github.com/IlyaGusev/summarus/blob/master/external/bart_scripts/train.sh)
+- Porting: [Colab link](https://colab.research.google.com/drive/13jXOlCpArV-lm4jZQ0VgOpj6nFBYrLAr)
 
 ## Eval results
 
@@ -98,7 +97,6 @@ def predict(
     predictions_file,
     targets_file,
     max_source_tokens_count=600,
-    max_target_tokens_count=160,
     use_cuda=True,
     batch_size=4
 ):
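The `max_target_tokens_count` argument disappears here because the matching `max_length` kwarg is dropped from `generate()` two hunks below; presumably (the commit message doesn't say) generation settings now come from the defaults saved in the model's config.json. A sketch of how those defaults can be inspected and overridden, assuming they are present in the config:

```python
from transformers import MBartForConditionalGeneration

model = MBartForConditionalGeneration.from_pretrained("IlyaGusev/mbart_ru_sum_gazeta")

# Standard PretrainedConfig generation fields; values depend on the repo's config.json.
print(model.config.max_length, model.config.num_beams, model.config.no_repeat_ngram_size)

# Explicit kwargs still override the saved defaults when needed, e.g. the old
# target cap of 160 + 2 special tokens:
# output_ids = model.generate(input_ids=input_ids, max_length=162, num_beams=5)
```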
@@ -115,9 +113,8 @@ def predict(
     model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
     predictions = []
     for batch in gen_batch(inputs, batch_size):
-        input_ids = tokenizer.prepare_seq2seq_batch(
+        input_ids = tokenizer(
             batch,
-            src_lang="en_XX",
             return_tensors="pt",
             padding="max_length",
             truncation=True,
@@ -125,12 +122,9 @@
         )["input_ids"].to(device)
         output_ids = model.generate(
             input_ids=input_ids,
-            max_length=max_target_tokens_count + 2,
-            no_repeat_ngram_size=3,
-            num_beams=5,
-            top_k=0
+            repetition_penalty=3.0
         )
-        summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+        summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         for s in summaries:
             print(s)
         predictions.extend(summaries)
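One more note for readers running this evaluation snippet standalone: it calls a `gen_batch` helper that is not part of this diff. A minimal sketch of what it presumably does, inferred from the call site `gen_batch(inputs, batch_size)`:

```python
def gen_batch(inputs, batch_size):
    # Yield consecutive chunks of at most batch_size items (hypothetical
    # reconstruction; the real helper lives outside this diff).
    for start in range(0, len(inputs), batch_size):
        yield inputs[start:start + batch_size]
```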
 