IlyaGusev commited on
Commit
4c9d0ca
1 Parent(s): d6c98e3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +4 -7
README.md CHANGED
@@ -45,7 +45,7 @@ article_text = "..."
45
 
46
  input_ids = tokenizer(
47
  [article_text],
48
- max_length=400,
49
  add_special_tokens=True,
50
  padding="max_length",
51
  truncation=True,
@@ -116,9 +116,8 @@ def predict(
116
  model_name,
117
  input_records,
118
  output_file,
119
- max_source_tokens_count=400,
120
- max_target_tokens_count=200,
121
- batch_size=16
122
  ):
123
  device = "cuda" if torch.cuda.is_available() else "cpu"
124
 
@@ -139,9 +138,7 @@ def predict(
139
 
140
  output_ids = model.generate(
141
  input_ids=input_ids,
142
- max_length=max_target_tokens_count,
143
- no_repeat_ngram_size=3,
144
- early_stopping=True
145
  )
146
  summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
147
  for s in summaries:
45
 
46
  input_ids = tokenizer(
47
  [article_text],
48
+ max_length=600,
49
  add_special_tokens=True,
50
  padding="max_length",
51
  truncation=True,
116
  model_name,
117
  input_records,
118
  output_file,
119
+ max_source_tokens_count=600,
120
+ batch_size=8
 
121
  ):
122
  device = "cuda" if torch.cuda.is_available() else "cpu"
123
 
138
 
139
  output_ids = model.generate(
140
  input_ids=input_ids,
141
+ no_repeat_ngram_size=4
 
 
142
  )
143
  summaries = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
144
  for s in summaries: