Iker commited on
Commit
95d0c3a
1 Parent(s): 56e0241

Implement num_return_sequences parameter.

Browse files

Define how many possible translations you want for each source sentence. Defualt:1

Files changed (1) hide show
  1. translate.py +10 -1
translate.py CHANGED
@@ -59,6 +59,7 @@ def main(
59
  precision: str = "32",
60
  max_length: int = 128,
61
  num_beams: int = 4,
 
62
  ):
63
 
64
  if not os.path.exists(os.path.abspath(os.path.dirname(output_path))):
@@ -96,7 +97,7 @@ def main(
96
  gen_kwargs = {
97
  "max_length": max_length,
98
  "num_beams": num_beams,
99
- "num_return_sequences": 1,
100
  }
101
 
102
  # total_lines: int = count_lines(sentences_path)
@@ -246,6 +247,13 @@ if __name__ == "__main__":
246
  help="Number of beams for beam search, m2m10 author recommends 5, but it might use too much memory",
247
  )
248
 
 
 
 
 
 
 
 
249
  parser.add_argument(
250
  "--precision",
251
  type=str,
@@ -266,5 +274,6 @@ if __name__ == "__main__":
266
  cache_dir=args.cache_dir,
267
  max_length=args.max_length,
268
  num_beams=args.num_beams,
 
269
  precision=args.precision,
270
  )
 
59
  precision: str = "32",
60
  max_length: int = 128,
61
  num_beams: int = 4,
62
+ num_return_sequences: int = 1,
63
  ):
64
 
65
  if not os.path.exists(os.path.abspath(os.path.dirname(output_path))):
 
97
  gen_kwargs = {
98
  "max_length": max_length,
99
  "num_beams": num_beams,
100
+ "num_return_sequences": num_return_sequences,
101
  }
102
 
103
  # total_lines: int = count_lines(sentences_path)
 
247
  help="Number of beams for beam search, m2m10 author recommends 5, but it might use too much memory",
248
  )
249
 
250
+ parser.add_argument(
251
+ "--num_return_sequences",
252
+ type=int,
253
+ default=1,
254
+ help="Number of possible translation to return for each sentence (num_return_sequences<=num_beams).",
255
+ )
256
+
257
  parser.add_argument(
258
  "--precision",
259
  type=str,
 
274
  cache_dir=args.cache_dir,
275
  max_length=args.max_length,
276
  num_beams=args.num_beams,
277
+ num_return_sequences=args.num_return_sequences,
278
  precision=args.precision,
279
  )