pszemraj commited on
Commit
dcce2ac
1 Parent(s): 1d116fb

add cli arg to 🅱️oost 🅱️eams

Browse files

Signed-off-by: peter szemraj <peterszemraj@gmail.com>

Files changed (2) hide show
  1. aggregate.py +9 -0
  2. app.py +14 -0
aggregate.py CHANGED
@@ -179,6 +179,15 @@ class BatchAggregator:
179
 
180
  self.aggregator.model.generation_config.update(**kwargs)
181
 
 
 
 
 
 
 
 
 
 
182
  def update_loglevel(self, level: str = "INFO"):
183
  """
184
  Update the log level.
 
179
 
180
  self.aggregator.model.generation_config.update(**kwargs)
181
 
182
+ def get_generation_config(self) -> dict:
183
+ """
184
+ Get the current generation configuration.
185
+
186
+ Returns:
187
+ dict: The current generation configuration.
188
+ """
189
+ return self.aggregator.model.generation_config.to_dict()
190
+
191
  def update_loglevel(self, level: str = "INFO"):
192
  """
193
  Update the log level.
app.py CHANGED
@@ -427,6 +427,14 @@ def parse_args():
427
  default=None,
428
  help=f"Add a token batch size to the demo UI options, default: {pp.pformat(TOKEN_BATCH_OPTIONS, compact=True)}",
429
  )
 
 
 
 
 
 
 
 
430
  parser.add_argument(
431
  "-level",
432
  "--log_level",
@@ -460,6 +468,12 @@ if __name__ == "__main__":
460
  logger.info(f"Adding token batch option {args.token_batch_option} to the list")
461
  TOKEN_BATCH_OPTIONS.append(args.token_batch_option)
462
 
 
 
 
 
 
 
463
  logger.info("Loading OCR model")
464
  with contextlib.redirect_stdout(None):
465
  ocr_model = ocr_predictor(
 
427
  default=None,
428
  help=f"Add a token batch size to the demo UI options, default: {pp.pformat(TOKEN_BATCH_OPTIONS, compact=True)}",
429
  )
430
+ parser.add_argument(
431
+ "-max_agg",
432
+ "-2x",
433
+ "--aggregator_beam_boost",
434
+ dest="aggregator_beam_boost",
435
+ action="store_true",
436
+ help="Double the number of beams for the aggregator during beam search",
437
+ )
438
  parser.add_argument(
439
  "-level",
440
  "--log_level",
 
468
  logger.info(f"Adding token batch option {args.token_batch_option} to the list")
469
  TOKEN_BATCH_OPTIONS.append(args.token_batch_option)
470
 
471
+ if args.aggregator_beam_boost:
472
+ logger.info("Doubling aggregator num_beams")
473
+ _agg_cfg = aggregator.get_generation_config()
474
+ _agg_cfg["num_beams"] = _agg_cfg["num_beams"] * 2
475
+ aggregator.update_generation_config(**_agg_cfg)
476
+
477
  logger.info("Loading OCR model")
478
  with contextlib.redirect_stdout(None):
479
  ocr_model = ocr_predictor(