tjhalanigrid commited on
Commit
c86b5fc
·
1 Parent(s): 4274ab3
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -543,8 +543,16 @@ with gr.Blocks(title="Text-to-SQL RLHF") as demo:
543
 
544
  with gr.Accordion("Task 1: Parallel Reward Benchmark", open=False):
545
  gr.Markdown("*(Simulates the heavy RLHF training workload by running hundreds of complex SQL queries concurrently to test SQLite multi-threading performance.)*")
 
 
 
 
546
  t1_n = gr.Number(value=20, precision=0, label="Rollouts (n)")
547
- t1_workers = gr.Number(value=10, precision=0, label="Max workers")
 
 
 
 
548
  t1_run = gr.Button("Run Task 1 benchmark")
549
  t1_out = gr.Textbox(label="Output", lines=12)
550
  t1_plot = gr.HTML(label="Plot (if generated)")
 
543
 
544
  with gr.Accordion("Task 1: Parallel Reward Benchmark", open=False):
545
  gr.Markdown("*(Simulates the heavy RLHF training workload by running hundreds of complex SQL queries concurrently to test SQLite multi-threading performance.)*")
546
+ # *******************
547
+ core_count = os.cpu_count() or 2
548
+ smart_worker_default = min(32, core_count * 2)
549
+
550
  t1_n = gr.Number(value=20, precision=0, label="Rollouts (n)")
551
+ t1_workers = gr.Number(value=smart_worker_default, precision=0, label=f"Max workers (Auto-detected: {core_count} cores)")
552
+ #****************
553
+ # t1_n = gr.Number(value=20, precision=0, label="Rollouts (n)")
554
+
555
+ # t1_workers = gr.Number(value=10, precision=0, label="Max workers")
556
  t1_run = gr.Button("Run Task 1 benchmark")
557
  t1_out = gr.Textbox(label="Output", lines=12)
558
  t1_plot = gr.HTML(label="Plot (if generated)")