Jae-Won Chung and AmberLJC committed
Commit 8ff63e4 (1 parent: b08a0ac)

The ML.ENERGY Colosseum (#22)


Co-authored-by: AmberLJC <amberljc@umich.edu>

.gitignore CHANGED
@@ -7,3 +7,12 @@
  # Editor
  pyrightconfig.json
  .idea
+
+ # Python
+ *.egg-info
+ **/__pycache__
+ build/
+
+ # Data files
+ *.log
+ pegasus/consumed.yaml
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: "⚡"
  python_version: "3.9"
  app_file: "app.py"
  sdk: "gradio"
- sdk_version: "3.35.2"
+ sdk_version: "3.39.0"
  pinned: true
  tags: ["energy", "leaderboard"]
  colorFrom: "black"
@@ -22,7 +22,12 @@ How much energy do LLMs consume?
  This README focuses on explaining how to run the benchmark yourself.
  The actual leaderboard is here: https://ml.energy/leaderboard.

- ## Setup
+ ## Colosseum
+
+ We instrumented [Hugging Face TGI](https://github.com/huggingface/text-generation-inference) so that it measures and returns GPU energy consumption.
+ Then, our [controller](/spitfight/colosseum/controller) server receives user prompts from the [Gradio app](/app.py), selects two models randomly, and streams model responses back with energy consumption.
+
+ ## Setup for benchmarking

  ### Model weights

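As a rough usage sketch (not part of this commit's diff), the Colosseum flow described above can be driven through the `ControllerClient` added below in `spitfight/colosseum/client.py`. The controller address is an assumption based on the port mapping in `deployment/controller-container.sh`.

```python
# Minimal sketch of the Colosseum flow, assuming a running controller at the
# address below (port 7778 comes from deployment/controller-container.sh).
from spitfight.colosseum.client import ControllerClient

client = ControllerClient(controller_addr="localhost:7778", timeout=15).fork()

# Stream both anonymous models' responses to the same prompt.
for index in (0, 1):
    print(f"--- Model {index} ---")
    for chunk in client.prompt(prompt="How do GPUs consume energy?", index=index):
        print(chunk, end="", flush=True)
    print()

# Vote for the response you preferred; the controller then reveals the model
# names and each response's measured GPU energy consumption.
vote = client.response_vote(victory_index=0)
print(vote.model_names, vote.energy_consumptions)
```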
app.py CHANGED
@@ -5,6 +5,9 @@ import yaml
5
  import requests
6
  import itertools
7
  import contextlib
8
  from dateutil import parser, tz
9
 
10
  import numpy as np
@@ -13,9 +16,10 @@ import pandas as pd
13
  import plotly.io as pio
14
  import plotly.express as px
15
  from pandas.api.types import is_numeric_dtype, is_float_dtype
16
-
17
  pio.templates.default = "plotly_white"
18
 
 
 
19
 
20
  class TableManager:
21
  def __init__(self, data_dir: str) -> None:
@@ -215,7 +219,6 @@ class TableManager:
215
 
216
  return fig, width, height, ""
217
 
218
-
219
  # The global instance of the TableManager should only be used when
220
  # initializing components in the Gradio interface. If the global instance
221
  # is mutated while handling user sessions, the change will be reflected
@@ -280,7 +283,7 @@ function format_model_link() {{
280
  """
281
 
282
  # Custom CSS.
283
- css = """
284
  /* Make ML.ENERGY look like a clickable logo. */
285
  .text-logo {
286
  color: #23d175 !important;
@@ -311,6 +314,14 @@ table th:first-child {
311
  .tab-nav > button {
312
  font-size: 18px !important;
313
  }
314
  """
315
 
316
  intro_text = """
@@ -324,13 +335,262 @@ including the ARC Challenge (reasoning), HellaSwag (common sense), and TruthfulQ
324
  Every benchmark is limited in some sense -- Before you interpret the results, please take a look at the <b>Limitations</b> section there, too.</p>
325
  """
326
 
327
- block = gr.Blocks(css=css)
328
- with block:
329
  tbm = gr.State(global_tbm) # type: ignore
330
  with gr.Box():
331
  gr.HTML("<h1><a href='https://ml.energy' class='text-logo'>ML.ENERGY</a> Leaderboard</h1>")
332
 
333
  with gr.Tabs():
334
  # Tab: Leaderboard.
335
  with gr.Tab("Leaderboard"):
336
  with gr.Box():
@@ -340,7 +600,7 @@ with block:
340
  with gr.Row():
341
  with gr.Box():
342
  gr.Markdown("### Benchmark results to show")
343
- checkboxes = []
344
  for key, choices in global_tbm.schema.items():
345
  # Specifying `value` makes everything checked by default.
346
  checkboxes.append(gr.CheckboxGroup(choices=choices, value=choices[:1], label=key))
@@ -349,10 +609,10 @@ with block:
349
  with gr.Row():
350
  dataframe = gr.Dataframe(type="pandas", elem_id="tab-leaderboard")
351
  # Make sure the models have clickable links.
352
- dataframe.change(None, None, None, _js=dataframe_update_js)
353
  # Table automatically updates when users check or uncheck any checkbox.
354
  for checkbox in checkboxes:
355
- checkbox.change(TableManager.set_filter_get_df, inputs=[tbm, *checkboxes], outputs=dataframe)
356
 
357
  # Block: Allow users to add new columns.
358
  with gr.Box():
@@ -381,21 +641,25 @@ with block:
381
  TableManager.add_column,
382
  inputs=[tbm, colname_input, formula_input],
383
  outputs=[dataframe, add_col_message],
 
384
  )
385
  formula_input.submit(
386
  TableManager.add_column,
387
  inputs=[tbm, colname_input, formula_input],
388
  outputs=[dataframe, add_col_message],
 
389
  )
390
  add_col_btn.click(
391
  TableManager.add_column,
392
  inputs=[tbm, colname_input, formula_input],
393
  outputs=[dataframe, add_col_message],
 
394
  )
395
  clear_input_btn.click(
396
  lambda: (None, None, None),
397
  inputs=None,
398
  outputs=[colname_input, formula_input, add_col_message],
 
399
  )
400
 
401
  # Block: Allow users to plot 2D and 3D scatter plots.
@@ -425,42 +689,51 @@ with block:
425
  )[0]) # type: ignore
426
  with gr.Row():
427
  plot_message = gr.HTML("")
428
- add_col_btn.click(TableManager.update_dropdown, inputs=tbm, outputs=axis_dropdowns) # type: ignore
429
  plot_width_input.submit(
430
  TableManager.plot_scatter,
431
  inputs=[tbm, plot_width_input, plot_height_input, *axis_dropdowns],
432
  outputs=[plot, plot_width_input, plot_height_input, plot_message],
 
433
  )
434
  plot_height_input.submit(
435
  TableManager.plot_scatter,
436
  inputs=[tbm, plot_width_input, plot_height_input, *axis_dropdowns],
437
  outputs=[plot, plot_width_input, plot_height_input, plot_message],
 
438
  )
439
  plot_btn.click(
440
  TableManager.plot_scatter,
441
  inputs=[tbm, plot_width_input, plot_height_input, *axis_dropdowns],
442
  outputs=[plot, plot_width_input, plot_height_input, plot_message],
 
443
  )
444
  clear_plot_btn.click(
445
  lambda: (None,) * 7,
446
  None,
447
  outputs=[*axis_dropdowns, plot, plot_width_input, plot_height_input, plot_message],
 
448
  )
449
 
450
  # Block: Leaderboard date.
451
  with gr.Row():
452
  gr.HTML(f"<h3 style='color: gray'>Last updated: {current_date}</h3>")
453
 
454
- # Tab: Online demo.
455
- with gr.Tab("Online demo (Coming in August!)"):
456
- gr.Markdown("# Online demo with real time energy measurements\n\nComing soon in August!")
457
-
458
  # Tab: About page.
459
  with gr.Tab("About"):
460
  # Read in LEADERBOARD.md
461
- gr.Markdown(open("LEADERBOARD.md").read())
462
 
463
  # Load the table on page load.
464
  block.load(lambda: global_tbm.set_filter_get_df(), outputs=dataframe)
465
 
466
- block.launch()
5
  import requests
6
  import itertools
7
  import contextlib
8
+ import argparse
9
+ import os
10
+ from typing import Literal
11
  from dateutil import parser, tz
12
 
13
  import numpy as np
 
16
  import plotly.io as pio
17
  import plotly.express as px
18
  from pandas.api.types import is_numeric_dtype, is_float_dtype
 
19
  pio.templates.default = "plotly_white"
20
 
21
+ from spitfight.colosseum.client import ControllerClient
22
+
23
 
24
  class TableManager:
25
  def __init__(self, data_dir: str) -> None:
 
219
 
220
  return fig, width, height, ""
221
 
 
222
  # The global instance of the TableManager should only be used when
223
  # initializing components in the Gradio interface. If the global instance
224
  # is mutated while handling user sessions, the change will be reflected
 
283
  """
284
 
285
  # Custom CSS.
286
+ custom_css = """
287
  /* Make ML.ENERGY look like a clickable logo. */
288
  .text-logo {
289
  color: #23d175 !important;
 
314
  .tab-nav > button {
315
  font-size: 18px !important;
316
  }
317
+
318
+ /* Color texts. */
319
+ .green-text {
320
+ color: #23d175 !important;
321
+ }
322
+ .red-text {
323
+ color: #ff3860 !important;
324
+ }
325
  """
326
 
327
  intro_text = """
 
335
  Every benchmark is limited in some sense -- Before you interpret the results, please take a look at the <b>Limitations</b> section there, too.</p>
336
  """
337
 
338
+ # The app will not start without a controller address set.
339
+ controller_addr = os.environ["COLOSSEUM_CONTROLLER_ADDR"]
340
+ global_controller_client = ControllerClient(controller_addr=controller_addr, timeout=15)
341
+
342
+ ANONYMOUS_MODEL_TEXT = "## Anonymous 🤫"
343
+
344
+ # Colosseum helper functions.
345
+ def enable_interact():
346
+ return [gr.update(interactive=True)] * 2
347
+
348
+ def disable_interact():
349
+ return [gr.update(interactive=False)] * 2
350
+
351
+ def consumed_less_energy_message(energy_a, energy_b):
352
+ """Return a message that indicates that the user chose the model that consumed less energy.
353
+
354
+ Report the difference in percent by default, but as a multiplier ("N.N X") when it is at least a factor of two.
355
+ """
356
+ less_energy = min(energy_a, energy_b)
357
+ more_energy = max(energy_a, energy_b)
358
+ factor = less_energy / more_energy
359
+ if factor <= 0.5:
360
+ message = f"<h2>That response also <span class='green-text'>consumed {1/factor:.1f}X less energy</span>!</h2>"
361
+ else:
362
+ message = f"<h2>That response also <span class='green-text'>consumed {100 - factor * 100:.1f}% less energy</span>!</h2>"
363
+ return message
364
+
365
+ def consumed_more_energy_message(energy_a, energy_b):
366
+ """Return a message that indicates that the user chose the model that consumed more energy.
367
+
368
+ Report the difference in percent by default, but as a multiplier ("N.N X") when it is at least a factor of two.
369
+ """
370
+ less_energy = min(energy_a, energy_b)
371
+ more_energy = max(energy_a, energy_b)
372
+ factor = more_energy / less_energy
373
+ if factor >= 2.0:
374
+ message = f"<h2>That response <span class='red-text'>consumed {factor:.1f}x more energy</span>.</h2>"
375
+ else:
376
+ message = f"<h2>That response <span class='red-text'>consumed {factor * 100 - 100:.1f}% more energy</span>.</h2>"
377
+ return message
378
+
379
+ # Colosseum event handlers
380
+ def add_prompt_disable_submit(prompt, history_a, history_b):
381
+ """Add the user's prompt to the two model's history and disable the submit button."""
382
+ client = global_controller_client.fork()
383
+ return [
384
+ gr.Textbox.update(value=" ", interactive=False),
385
+ gr.Button.update(interactive=False),
386
+ history_a + [[prompt, ""]],
387
+ history_b + [[prompt, ""]],
388
+ client,
389
+ ]
390
+
391
+ def generate_responses(client: ControllerClient, history_a, history_b):
392
+ """Generate responses for the two models."""
393
+ for resp_a, resp_b in itertools.zip_longest(
394
+ client.prompt(prompt=history_a[-1][0], index=0),
395
+ client.prompt(prompt=history_b[-1][0], index=1),
396
+ ):
397
+ if resp_a is not None:
398
+ history_a[-1][1] += resp_a
399
+ if resp_b is not None:
400
+ history_b[-1][1] += resp_b
401
+ yield [history_a, history_b]
402
+
403
+ def make_resp_vote_func(victory_index: Literal[0, 1]):
404
+ """Return a function that will be called when the user clicks on response preference vote buttons."""
405
+ def resp_vote_func(client: ControllerClient):
406
+ vote_response = client.response_vote(victory_index=victory_index)
407
+ model_name_a, model_name_b = map(lambda n: f"## {n}", vote_response.model_names)
408
+ energy_a, energy_b = vote_response.energy_consumptions
409
+ # User liked the model that also consumed less energy.
410
+ if (victory_index == 0 and energy_a <= energy_b) or (victory_index == 1 and energy_a >= energy_b):
411
+ energy_message = consumed_less_energy_message(energy_a, energy_b)
412
+ return [
413
+ # Disable response vote buttons
414
+ gr.Button.update(interactive=False), gr.Button.update(interactive=False),
415
+ # Reveal model names
416
+ gr.Markdown.update(model_name_a), gr.Markdown.update(model_name_b),
417
+ # Display energy consumption comparison message
418
+ gr.Markdown.update(energy_message, visible=True),
419
+ # Keep energy vote buttons hidden
420
+ gr.Button.update(visible=False, interactive=False), gr.Button.update(visible=False, interactive=False),
421
+ # Enable reset button
422
+ gr.Button.update(visible=True, interactive=True),
423
+ ]
424
+ # User liked the model that consumed more energy.
425
+ else:
426
+ energy_message = consumed_more_energy_message(energy_a, energy_b)
427
+ return [
428
+ # Disable response vote buttons
429
+ gr.Button.update(interactive=False), gr.Button.update(interactive=False),
430
+ # Leave model names hidden
431
+ gr.Markdown.update(ANONYMOUS_MODEL_TEXT), gr.Markdown.update(ANONYMOUS_MODEL_TEXT),
432
+ # Display energy consumption comparison message
433
+ gr.Markdown.update(energy_message, visible=True),
434
+ # Reveal and enable energy vote buttons
435
+ gr.Button.update(visible=True, interactive=True), gr.Button.update(visible=True, interactive=True),
436
+ # Keep the reset button disabled
437
+ gr.Button.update(visible=False, interactive=False),
438
+ ]
439
+ return resp_vote_func
440
+
441
+ def make_energy_vote_func(is_worth: bool):
442
+ """Return a function that will be called when the user clicks on energy vote buttons."""
443
+ def energy_vote_func(client: ControllerClient, energy_message: str):
444
+ vote_response = client.energy_vote(is_worth=is_worth)
445
+ model_name_a, model_name_b = map(lambda n: f"## {n}", vote_response.model_names)
446
+ return [
447
+ # Reveal model names
448
+ gr.Markdown.update(model_name_a), gr.Markdown.update(model_name_b),
449
+ # Disable energy vote buttons
450
+ gr.Button.update(interactive=False), gr.Button.update(interactive=False),
451
+ # Enable reset button
452
+ gr.Button.update(interactive=True, visible=True),
453
+ # Append to the energy comparison message
454
+ energy_message[:-5] + (" Fair enough.</h2>" if is_worth else " Wasn't worth it.</h2>"),
455
+ ]
456
+ return energy_vote_func
457
+
458
+ def play_again():
459
+ return [
460
+ # Clear chatbot history
461
+ None, None,
462
+ # Turn on prompt textbox and submit button
463
+ gr.Textbox.update(value="", interactive=True), gr.Button.update(interactive=True),
464
+ # Mask model names
465
+ gr.Markdown.update(ANONYMOUS_MODEL_TEXT),
466
+ gr.Markdown.update(ANONYMOUS_MODEL_TEXT),
467
+ # Hide energy vote buttons and message
468
+ gr.Button.update(visible=False), gr.Button.update(visible=False), gr.Markdown.update(visible=False),
469
+ # Disable reset button
470
+ gr.Button.update(interactive=False, visible=False),
471
+ ]
472
+
473
+ focus_prompt_input_js = """
474
+ function() {
475
+ for (let textarea of document.getElementsByTagName("textarea")) {
476
+ if (textarea.hasAttribute("autofocus")) {
477
+ textarea.focus();
478
+ return;
479
+ }
480
+ }
481
+ }
482
+ """
483
+
484
+ with gr.Blocks(css=custom_css) as block:
485
  tbm = gr.State(global_tbm) # type: ignore
486
  with gr.Box():
487
  gr.HTML("<h1><a href='https://ml.energy' class='text-logo'>ML.ENERGY</a> Leaderboard</h1>")
488
 
489
  with gr.Tabs():
490
+ # Tab: Colosseum.
491
+ with gr.TabItem("Colosseum ⚔️️"):
492
+ gr.Markdown(open("docs/colosseum_top.md").read())
493
+
494
+ with gr.Group():
495
+ with gr.Row():
496
+ prompt_input = gr.Textbox(
497
+ show_label=False,
498
+ placeholder="Type your prompt and press ENTER",
499
+ autofocus=True,
500
+ container=False,
501
+ scale=20,
502
+ elem_id="prompt-textarea",
503
+ )
504
+ prompt_submit_btn = gr.Button(
505
+ value="⚔️️ Fight!",
506
+ elem_classes=["btn-submit"],
507
+ min_width=60,
508
+ scale=1,
509
+ )
510
+
511
+ with gr.Row():
512
+ masked_model_names = []
513
+ chatbots = []
514
+ resp_vote_btn_list: list[gr.component.Component] = []
515
+ with gr.Column():
516
+ with gr.Row():
517
+ masked_model_names.append(gr.Markdown(ANONYMOUS_MODEL_TEXT))
518
+ with gr.Row():
519
+ chatbots.append(gr.Chatbot(label="Model A", elem_id="chatbot", height=600))
520
+ with gr.Row():
521
+ left_resp_vote_btn = gr.Button(value="👈 Model A is better", interactive=False)
522
+ resp_vote_btn_list.append(left_resp_vote_btn)
523
+
524
+ with gr.Column():
525
+ with gr.Row():
526
+ masked_model_names.append(gr.Markdown(ANONYMOUS_MODEL_TEXT))
527
+ with gr.Row():
528
+ chatbots.append(gr.Chatbot(label="Model B", elem_id="chatbot", height=600))
529
+ with gr.Row():
530
+ right_resp_vote_btn = gr.Button(value="👉 Model B is better", interactive=False)
531
+ resp_vote_btn_list.append(right_resp_vote_btn)
532
+
533
+ with gr.Row():
534
+ energy_comparison_message = gr.HTML(visible=False)
535
+
536
+ with gr.Row():
537
+ worth_energy_vote_btn = gr.Button(value="The better response was worth the extra energy.", visible=False)
538
+ notworth_energy_vote_btn = gr.Button(value="Not really worth it.", visible=False)
539
+ energy_vote_btn_list: list[gr.component.Component] = [worth_energy_vote_btn, notworth_energy_vote_btn]
540
+
541
+ with gr.Row():
542
+ play_again_btn = gr.Button("Play again!", visible=False)
543
+
544
+ gr.Markdown(open("docs/colosseum_bottom.md").read())
545
+
546
+ controller_client = gr.State()
547
+
548
+
549
+ (prompt_input
550
+ .submit(add_prompt_disable_submit, [prompt_input, *chatbots], [prompt_input, prompt_submit_btn, *chatbots, controller_client], queue=False)
551
+ .then(generate_responses, [controller_client, *chatbots], [*chatbots], queue=True)
552
+ .then(enable_interact, None, resp_vote_btn_list, queue=False))
553
+ (prompt_submit_btn
554
+ .click(add_prompt_disable_submit, [prompt_input, *chatbots], [prompt_input, prompt_submit_btn, *chatbots, controller_client], queue=False)
555
+ .then(generate_responses, [controller_client, *chatbots], [*chatbots], queue=True)
556
+ .then(enable_interact, None, resp_vote_btn_list, queue=False))
557
+
558
+ left_resp_vote_btn.click(
559
+ make_resp_vote_func(victory_index=0),
560
+ [controller_client],
561
+ [*resp_vote_btn_list, *masked_model_names, energy_comparison_message, *energy_vote_btn_list, play_again_btn],
562
+ queue=False,
563
+ )
564
+ right_resp_vote_btn.click(
565
+ make_resp_vote_func(victory_index=1),
566
+ [controller_client],
567
+ [*resp_vote_btn_list, *masked_model_names, energy_comparison_message, *energy_vote_btn_list, play_again_btn],
568
+ queue=False,
569
+ )
570
+
571
+ worth_energy_vote_btn.click(
572
+ make_energy_vote_func(is_worth=True),
573
+ [controller_client, energy_comparison_message],
574
+ [*masked_model_names, *energy_vote_btn_list, play_again_btn, energy_comparison_message],
575
+ queue=False,
576
+ )
577
+ notworth_energy_vote_btn.click(
578
+ make_energy_vote_func(is_worth=False),
579
+ [controller_client, energy_comparison_message],
580
+ [*masked_model_names, *energy_vote_btn_list, play_again_btn, energy_comparison_message],
581
+ queue=False,
582
+ )
583
+
584
+ (play_again_btn
585
+ .click(
586
+ play_again,
587
+ None,
588
+ [*chatbots, prompt_input, prompt_submit_btn, *masked_model_names, *energy_vote_btn_list, energy_comparison_message, play_again_btn],
589
+ queue=False,
590
+ )
591
+ .then(None, _js=focus_prompt_input_js, queue=False))
592
+
593
+
594
  # Tab: Leaderboard.
595
  with gr.Tab("Leaderboard"):
596
  with gr.Box():
 
600
  with gr.Row():
601
  with gr.Box():
602
  gr.Markdown("### Benchmark results to show")
603
+ checkboxes: list[gr.CheckboxGroup] = []
604
  for key, choices in global_tbm.schema.items():
605
  # Specifying `value` makes everything checked by default.
606
  checkboxes.append(gr.CheckboxGroup(choices=choices, value=choices[:1], label=key))
 
609
  with gr.Row():
610
  dataframe = gr.Dataframe(type="pandas", elem_id="tab-leaderboard")
611
  # Make sure the models have clickable links.
612
+ dataframe.change(None, None, None, _js=dataframe_update_js, queue=False)
613
  # Table automatically updates when users check or uncheck any checkbox.
614
  for checkbox in checkboxes:
615
+ checkbox.change(TableManager.set_filter_get_df, inputs=[tbm, *checkboxes], outputs=dataframe, queue=False)
616
 
617
  # Block: Allow users to add new columns.
618
  with gr.Box():
 
641
  TableManager.add_column,
642
  inputs=[tbm, colname_input, formula_input],
643
  outputs=[dataframe, add_col_message],
644
+ queue=False,
645
  )
646
  formula_input.submit(
647
  TableManager.add_column,
648
  inputs=[tbm, colname_input, formula_input],
649
  outputs=[dataframe, add_col_message],
650
+ queue=False,
651
  )
652
  add_col_btn.click(
653
  TableManager.add_column,
654
  inputs=[tbm, colname_input, formula_input],
655
  outputs=[dataframe, add_col_message],
656
+ queue=False,
657
  )
658
  clear_input_btn.click(
659
  lambda: (None, None, None),
660
  inputs=None,
661
  outputs=[colname_input, formula_input, add_col_message],
662
+ queue=False,
663
  )
664
 
665
  # Block: Allow users to plot 2D and 3D scatter plots.
 
689
  )[0]) # type: ignore
690
  with gr.Row():
691
  plot_message = gr.HTML("")
692
+ add_col_btn.click(TableManager.update_dropdown, inputs=tbm, outputs=axis_dropdowns, queue=False) # type: ignore
693
  plot_width_input.submit(
694
  TableManager.plot_scatter,
695
  inputs=[tbm, plot_width_input, plot_height_input, *axis_dropdowns],
696
  outputs=[plot, plot_width_input, plot_height_input, plot_message],
697
+ queue=False,
698
  )
699
  plot_height_input.submit(
700
  TableManager.plot_scatter,
701
  inputs=[tbm, plot_width_input, plot_height_input, *axis_dropdowns],
702
  outputs=[plot, plot_width_input, plot_height_input, plot_message],
703
+ queue=False,
704
  )
705
  plot_btn.click(
706
  TableManager.plot_scatter,
707
  inputs=[tbm, plot_width_input, plot_height_input, *axis_dropdowns],
708
  outputs=[plot, plot_width_input, plot_height_input, plot_message],
709
+ queue=False,
710
  )
711
  clear_plot_btn.click(
712
  lambda: (None,) * 7,
713
  None,
714
  outputs=[*axis_dropdowns, plot, plot_width_input, plot_height_input, plot_message],
715
+ queue=False,
716
  )
717
 
718
  # Block: Leaderboard date.
719
  with gr.Row():
720
  gr.HTML(f"<h3 style='color: gray'>Last updated: {current_date}</h3>")
721
 
722
  # Tab: About page.
723
  with gr.Tab("About"):
724
  # Read in LEADERBOARD.md
725
+ gr.Markdown(open("docs/leaderboard.md").read())
726
 
727
  # Load the table on page load.
728
  block.load(lambda: global_tbm.set_filter_get_df(), outputs=dataframe)
729
 
730
+
731
+ if __name__ == "__main__":
732
+ parser = argparse.ArgumentParser()
733
+ parser.add_argument("--share", action="store_true", help="Specify if sharing is enabled")
734
+ parser.add_argument("--concurrency", type=int, default=10)
735
+
736
+ args = parser.parse_args()
737
+ block.queue(
738
+ concurrency_count=args.concurrency, status_update_rate=10, api_open=False
739
+ ).launch(share=args.share, show_error=True)
Dockerfile → deployment/benchmark.Dockerfile RENAMED
File without changes
deployment/controller-container.sh ADDED
@@ -0,0 +1,11 @@
1
+ #!/usr/bin/env bash
2
+
3
+ docker run \
4
+ --name controller \
5
+ --net leaderboard \
6
+ -v $HOME/workspace/leaderboard:/workspace/leaderboard \
7
+ -v $HOME/workspace/text-generation-inference/deployment:/workspace/text-generation-inference/deployment:ro \
8
+ -v /data/leaderboard/colosseum-controller-logs:/logs \
9
+ -p 7778:8000 \
10
+ -e LOG_DIR=/logs \
11
+ mlenergy/colosseum-controller:latest
deployment/controller.Dockerfile ADDED
@@ -0,0 +1,30 @@
1
+ FROM ubuntu:22.04
2
+
3
+ # Basic installs
4
+ ARG DEBIAN_FRONTEND=noninteractive
5
+ ENV TZ='America/Detroit'
6
+ RUN apt-get update -qq \
7
+ && apt-get -y --no-install-recommends install \
8
+ tzdata software-properties-common wget git \
9
+ && apt-get clean all \
10
+ && rm -r /var/lib/apt/lists/* \
11
+ && ln -fs /usr/share/zoneinfo/America/Detroit /etc/localtime \
12
+ && dpkg-reconfigure -f noninteractive tzdata
13
+
14
+ # Install Miniconda3 23.3.1
15
+ ENV PATH="/root/.local/miniconda3/bin:$PATH"
16
+ RUN mkdir -p /root/.local \
17
+ && wget https://repo.anaconda.com/miniconda/Miniconda3-py39_23.3.1-0-Linux-x86_64.sh \
18
+ && mkdir /root/.conda \
19
+ && bash Miniconda3-py39_23.3.1-0-Linux-x86_64.sh -b -p /root/.local/miniconda3 \
20
+ && rm -f Miniconda3-py39_23.3.1-0-Linux-x86_64.sh \
21
+ && ln -sf /root/.local/miniconda3/etc/profile.d/conda.sh /etc/profile.d/conda.sh
22
+
23
+ # Install spitfight
24
+ ADD . /workspace/leaderboard
25
+ RUN cd /workspace/leaderboard \
26
+ && pip install -e .[colosseum-controller]
27
+
28
+ WORKDIR /workspace/leaderboard
29
+
30
+ CMD ["python", "spitfight/colosseum/controller/router.py"]
deployment/docker-compose-0.yaml ADDED
@@ -0,0 +1,74 @@
1
+ services:
2
+ Falcon-7B:
3
+ container_name: worker0
4
+ image: mlenergy/tgi:latest
5
+ command: ["--model-id", "tiiuae/falcon-7b-instruct", "--num-shard", "1", "--otlp-endpoint", "http://jaeger:4317"]
6
+ shm_size: 1g
7
+ networks:
8
+ - leaderboard
9
+ volumes:
10
+ - /data/leaderboard/tgi-data:/data
11
+ deploy:
12
+ resources:
13
+ reservations:
14
+ devices:
15
+ - driver: nvidia
16
+ device_ids: ["0"]
17
+ capabilities: [gpu]
18
+ Llama2-7B:
19
+ container_name: worker1
20
+ image: mlenergy/tgi:latest
21
+ command: ["--model-id", "/weights/metaai/Llama-2-7b-chat-hf", "--num-shard", "1", "--otlp-endpoint", "http://jaeger:4317"]
22
+ shm_size: 1g
23
+ networks:
24
+ - leaderboard
25
+ volumes:
26
+ - /data/leaderboard/tgi-data:/data
27
+ - /data/leaderboard/weights:/weights
28
+ deploy:
29
+ resources:
30
+ reservations:
31
+ devices:
32
+ - driver: nvidia
33
+ device_ids: ["1"]
34
+ capabilities: [gpu]
35
+ FastChat-T5-3B:
36
+ container_name: worker2
37
+ image: mlenergy/tgi:latest
38
+ command: ["--model-id", "lmsys/fastchat-t5-3b-v1.0", "--num-shard", "1", "--otlp-endpoint", "http://jaeger:4317"]
39
+ environment:
40
+ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION: python
41
+ shm_size: 1g
42
+ networks:
43
+ - leaderboard
44
+ volumes:
45
+ - /data/leaderboard/tgi-data:/data
46
+ deploy:
47
+ resources:
48
+ reservations:
49
+ devices:
50
+ - driver: nvidia
51
+ device_ids: ["2"]
52
+ capabilities: [gpu]
53
+ Llama2-13B:
54
+ container_name: worker3
55
+ image: mlenergy/tgi:latest
56
+ command: ["--model-id", "/weights/metaai/Llama-2-13b-chat-hf", "--num-shard", "1", "--otlp-endpoint", "http://jaeger:4317"]
57
+ shm_size: 1g
58
+ networks:
59
+ - leaderboard
60
+ volumes:
61
+ - /data/leaderboard/tgi-data:/data
62
+ - /data/leaderboard/weights:/weights
63
+ deploy:
64
+ resources:
65
+ reservations:
66
+ devices:
67
+ - driver: nvidia
68
+ device_ids: ["3"]
69
+ capabilities: [gpu]
70
+
71
+ networks:
72
+ leaderboard:
73
+ name: leaderboard
74
+ external: true
deployment/docker-compose-1.yaml ADDED
@@ -0,0 +1,40 @@
1
+ services:
2
+ Llama2-70B-INT8:
3
+ container_name: worker4
4
+ image: mlenergy/tgi:latest
5
+ command: ["--model-id", "meta-llama/Llama-2-70b-chat-hf", "--num-shard", "2", "--otlp-endpoint", "http://jaeger:4317", "--quantize", "bitsandbytes"]
6
+ shm_size: 1g
7
+ environment:
8
+ HUGGING_FACE_HUB_TOKEN: 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
9
+ networks:
10
+ - leaderboard
11
+ volumes:
12
+ - /data/leaderboard/tgi-data:/data
13
+ deploy:
14
+ resources:
15
+ reservations:
16
+ devices:
17
+ - driver: nvidia
18
+ device_ids: ["0", "1"]
19
+ capabilities: [gpu]
20
+ Falcon-40B:
21
+ container_name: worker5
22
+ image: mlenergy/tgi:latest
23
+ command: ["--model-id", "tiiuae/falcon-40b-instruct", "--num-shard", "2", "--otlp-endpoint", "http://jaeger:4317"]
24
+ shm_size: 1g
25
+ networks:
26
+ - leaderboard
27
+ volumes:
28
+ - /data/leaderboard/tgi-data:/data
29
+ deploy:
30
+ resources:
31
+ reservations:
32
+ devices:
33
+ - driver: nvidia
34
+ device_ids: ["2", "3"]
35
+ capabilities: [gpu]
36
+
37
+ networks:
38
+ leaderboard:
39
+ name: leaderboard
40
+ external: true
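The two compose files above start one instrumented TGI container per model on the shared `leaderboard` network. As a hedged sketch (assuming the containers are up, the worker hostname is reachable from where you run this, and that the ML.ENERGY TGI fork pinned in `setup.py` attaches a per-token `energy` field), a worker can be queried directly with the `text_generation` client:

```python
# Sketch: query one TGI worker directly and sum per-token GPU energy.
# "worker0" and port 80 come from docker-compose-0.yaml; the `energy` field is
# an assumption about the instrumented TGI fork, not stock TGI.
from text_generation import Client

client = Client("http://worker0:80")
total_energy = 0.0
for resp in client.generate_stream("Why measure LLM energy?", max_new_tokens=64):
    total_energy += resp.token.energy
    print(resp.token.text, end="", flush=True)
print(f"\nTotal GPU energy for this response: {total_energy}")
```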
docs/colosseum_bottom.md ADDED
@@ -0,0 +1,14 @@
1
+ ### Terms of use
2
+
3
+ By using our service, you agree to these Terms of Use and accept that the Service provides an approximate estimation of model inference energy usage for research purposes only. We are not liable for any damages or loss incurred by you or any third party arising from the use of the Service. The Service may generate offensive content and offers only limited safety measures, so it must not be used for any illegal, harmful, violent, racist, or sexual purposes. The Service collects user dialogue data and voting results, and we reserve the right to distribute this dataset in the future.
4
+
5
+ ### Technical details
6
+
7
+ - We allow models to generate only up to 512 new tokens. Due to this, some responses may be cut off in the middle.
8
+ - Tokens are sampled from the model output with `temperature` 1.0, `repetition_penalty` 1.0, `top_k` 50, and `top_p` 0.95.
9
+ - Large models (>= 30B) run on two NVIDIA A40 GPUs with tensor parallelism, whereas other models run on one NVIDIA A40 GPU. We directly measure the energy consumption of these GPUs.
10
+
11
+ ### Contact
12
+
13
+ Please direct general questions and issues related to the Colosseum to our GitHub repository's [discussion board](https://github.com/ml-energy/leaderboard/discussions).
14
+ You can find the ML.ENERGY initiative members on [our homepage](https://ml.energy#members).
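The sampling settings listed under *Technical details* above mirror the controller defaults in `spitfight/colosseum/controller/router.py` later in this commit. As a hedged sketch, they correspond roughly to the following `text_generation` call (the worker URL is a placeholder):

```python
# Sketch: the Colosseum generation settings expressed as a single TGI call.
from text_generation import Client

client = Client("http://worker0:80")  # placeholder worker address
response = client.generate(
    "Your prompt here",
    max_new_tokens=512,       # responses are cut off after 512 new tokens
    do_sample=True,
    temperature=1.0,
    repetition_penalty=1.0,
    top_k=50,
    top_p=0.95,
)
print(response.generated_text)
```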
docs/colosseum_top.md ADDED
@@ -0,0 +1,8 @@
1
+ > Enter the ML.ENERGY Colosseum, where language models duel with intellect, and your judgment tips the scales of victory.
2
+
3
+ ### Rules of the Colosseum
4
+
5
+ - As the spectator, you'll decide the fates of two anonymous language models -- our gladiators.
6
+ - Your role is twofold: First, you vote for the model that delivered the best response to your prompt.
7
+ - Next, mighty [Zeus](https://ml.energy/zeus) will reveal which language model consumed more energy. Evaluate if its performance justified the energy consumption.
8
+ - Only after you cast votes will the models' identities be unveiled.
LEADERBOARD.md → docs/leaderboard.md RENAMED
@@ -3,7 +3,7 @@ The goal of the ML.ENERGY Leaderboard is to give people a sense of how much **en
  The code for the leaderboard, backing data, and scripts for benchmarking are all open-source in our [repository](https://github.com/ml-energy/leaderboard).
  We'll see you at the [Discussion board](https://github.com/ml-energy/leaderboard/discussions), where you can ask questions, suggest improvement ideas, or just discuss leaderboard results!

- ## Columns
+ ## Leaderboard Columns

  - `gpu`: NVIDIA GPU model name.
  - `task`: Name of the task. See *Tasks* below for details.
@@ -113,7 +113,7 @@ By doing this, we can provide numbers for reasonable comparison without being ti

  This leaderboard is a research preview intended for non-commercial use only.
  Model weights were taken as is from the Hugging Face Hub if available and are subject to their licenses.
- The use of LLaMA weights are subject to their [license](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md).
+ The use of Llama weights is subject to their [license](https://github.com/facebookresearch/llama/blob/main/LICENSE).
  Please direct inquiries/reports of potential violation to Jae-Won Chung.

  ## Acknowledgements
requirements.txt CHANGED
@@ -1,2 +1 @@
- plotly==5.15.0
- gradio==3.35.2
+ .[app]
setup.py ADDED
@@ -0,0 +1,21 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ extras_require = {
4
+ "colosseum-controller": [
5
+ "fastapi",
6
+ "fschat==0.2.23",
7
+ "text_generation @ git+https://github.com/ml-energy/text_generation_energy@master",
8
+ ],
9
+ "app": ["plotly==5.15.0", "gradio==3.39.0", "pydantic==1.10.9"],
10
+ "benchmark": ["zeus-ml", "fschat==0.2.23", "tyro", "rich"],
11
+ }
12
+
13
+ extras_require["all"] = list(set(sum(extras_require.values(), [])))
14
+
15
+ setup(
16
+ name="spitfight",
17
+ version="0.0.1",
18
+ url="https://github.com/ml-energy/leaderboard",
19
+ packages=find_packages("."),
20
+ extras_require=extras_require,
21
+ )
spitfight/__init__.py ADDED
File without changes
spitfight/colosseum/__init__.py ADDED
File without changes
spitfight/colosseum/client.py ADDED
@@ -0,0 +1,106 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import unittest
5
+ import contextlib
6
+ from uuid import uuid4, UUID
7
+ from copy import deepcopy
8
+ from typing import Generator, Literal
9
+
10
+ import requests
11
+ import gradio as gr
12
+
13
+ from spitfight.colosseum.common import (
14
+ COLOSSEUM_PROMPT_ROUTE,
15
+ COLOSSEUM_RESP_VOTE_ROUTE,
16
+ COLOSSEUM_ENERGY_VOTE_ROUTE,
17
+ PromptRequest,
18
+ ResponseVoteRequest,
19
+ ResponseVoteResponse,
20
+ EnergyVoteRequest,
21
+ EnergyVoteResponse,
22
+ )
23
+
24
+
25
+ class ControllerClient:
26
+ """Client for the Colosseum controller, to be used by Gradio."""
27
+
28
+ def __init__(self, controller_addr: str, timeout: int = 15, request_id: UUID | None = None) -> None:
29
+ """Initialize the controller client."""
30
+ self.controller_addr = controller_addr
31
+ self.timeout = timeout
32
+ self.request_id = str(request_id) if request_id is not None else str(uuid4())
33
+
34
+ def fork(self) -> ControllerClient:
35
+ """Return a copy of the client with a new request ID."""
36
+ return ControllerClient(
37
+ controller_addr=self.controller_addr,
38
+ timeout=self.timeout,
39
+ request_id=uuid4(),
40
+ )
41
+
42
+ def prompt(self, prompt: str, index: Literal[0, 1]) -> Generator[str, None, None]:
43
+ """Generate the response of the `index`th model with the prompt."""
44
+ prompt_request = PromptRequest(request_id=self.request_id, prompt=prompt, model_index=index)
45
+ with _catch_requests_exceptions():
46
+ resp = requests.post(
47
+ f"http://{self.controller_addr}{COLOSSEUM_PROMPT_ROUTE}",
48
+ json=prompt_request.dict(),
49
+ stream=True,
50
+ timeout=self.timeout,
51
+ )
52
+ _check_response(resp)
53
+ # XXX: Why can't the server just yield `text + "\n"` and here we just iter_lines?
54
+ for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
55
+ if chunk:
56
+ yield json.loads(chunk.decode("utf-8"))
57
+
58
+ def response_vote(self, victory_index: Literal[0, 1]) -> ResponseVoteResponse:
59
+ """Notify the controller of the user's vote for the response."""
60
+ response_vote_request = ResponseVoteRequest(request_id=self.request_id, victory_index=victory_index)
61
+ with _catch_requests_exceptions():
62
+ resp = requests.post(
63
+ f"http://{self.controller_addr}{COLOSSEUM_RESP_VOTE_ROUTE}",
64
+ json=response_vote_request.dict(),
65
+ )
66
+ _check_response(resp)
67
+ return ResponseVoteResponse(**resp.json())
68
+
69
+ def energy_vote(self, is_worth: bool) -> EnergyVoteResponse:
70
+ """Notify the controller of the user's vote for energy."""
71
+ energy_vote_request = EnergyVoteRequest(request_id=self.request_id, is_worth=is_worth)
72
+ with _catch_requests_exceptions():
73
+ resp = requests.post(
74
+ f"http://{self.controller_addr}{COLOSSEUM_ENERGY_VOTE_ROUTE}",
75
+ json=energy_vote_request.dict(),
76
+ )
77
+ _check_response(resp)
78
+ return EnergyVoteResponse(**resp.json())
79
+
80
+
81
+ @contextlib.contextmanager
82
+ def _catch_requests_exceptions():
83
+ """Catch requests exceptions and raise gr.Error instead."""
84
+ try:
85
+ yield
86
+ except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
87
+ raise gr.Error("Failed to connect to our the backend server. Please try again later.")
88
+
89
+
90
+ def _check_response(response: requests.Response) -> None:
91
+ if 400 <= response.status_code < 500:
92
+ raise gr.Error(response.json()["detail"])
93
+ elif response.status_code >= 500:
94
+ raise gr.Error("Failed to talk to our backend server. Please try again later.")
95
+
96
+
97
+ class TestControllerClient(unittest.TestCase):
98
+ def test_new_uuid_on_deepcopy(self):
99
+ client = ControllerClient("http://localhost:8000")
100
+ clients = [client.fork() for _ in range(50)]
101
+ request_ids = [client.request_id for client in clients]
102
+ assert len(set(request_ids)) == len(request_ids)
103
+
104
+
105
+ if __name__ == "__main__":
106
+ unittest.main()
spitfight/colosseum/common.py ADDED
@@ -0,0 +1,35 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Literal
4
+
5
+ from pydantic import BaseModel
6
+
7
+ COLOSSEUM_PROMPT_ROUTE = "/prompt"
8
+ COLOSSEUM_RESP_VOTE_ROUTE = "/response_vote"
9
+ COLOSSEUM_ENERGY_VOTE_ROUTE = "/energy_vote"
10
+ COLOSSEUM_HEALTH_ROUTE = "/health"
11
+
12
+
13
+ class PromptRequest(BaseModel):
14
+ request_id: str
15
+ prompt: str
16
+ model_index: Literal[0, 1]
17
+
18
+
19
+ class ResponseVoteRequest(BaseModel):
20
+ request_id: str
21
+ victory_index: Literal[0, 1]
22
+
23
+
24
+ class ResponseVoteResponse(BaseModel):
25
+ model_names: list[str]
26
+ energy_consumptions: list[float]
27
+
28
+
29
+ class EnergyVoteRequest(BaseModel):
30
+ request_id: str
31
+ is_worth: bool
32
+
33
+
34
+ class EnergyVoteResponse(BaseModel):
35
+ model_names: list[str]
spitfight/colosseum/controller/__init__.py ADDED
File without changes
spitfight/colosseum/controller/controller.py ADDED
@@ -0,0 +1,266 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import asyncio
5
+ from datetime import datetime
6
+ from typing import AsyncGenerator, Literal, Optional, TYPE_CHECKING
7
+
8
+ import aiohttp
9
+ from pytz import timezone
10
+ from pydantic import BaseModel, Field
11
+
12
+ from spitfight.log import get_logger
13
+ from spitfight.utils import BoundedExpiringDict, TokenGenerationBuffer, create_task
14
+ from spitfight.colosseum.controller.worker import WorkerService
15
+ from spitfight.prompt import get_system_prompt, apply_model_characteristics
16
+
17
+ if TYPE_CHECKING:
18
+ from spitfight.colosseum.controller.router import ControllerConfig
19
+
20
+ controller_logger = get_logger(__name__)
21
+ request_logger = get_logger("colosseum_requests")
22
+
23
+
24
+ def now() -> datetime:
25
+ return datetime.now(tz=timezone("US/Eastern"))
26
+
27
+
28
+ # Internal states
29
+ # The two "chose_*" stages are both the result of voting on a response.
30
+ # A normal user will sequentially go through either
31
+ # "prompted" -> "chose_less_energy_response", or
32
+ # "prompted" -> "chose_more_energy_response" -> "voted_energy"
33
+ UserStage = Literal[
34
+ "prompted",
35
+ "chose_less_energy_response",
36
+ "chose_more_energy_response",
37
+ "voted_energy",
38
+ ]
39
+
40
+
41
+ class RequestState(BaseModel):
42
+ """Models the state of a Colosseum play.
43
+
44
+ This model is also serialized as is and logged.
45
+ """
46
+ request_id: str
47
+ prompt: str
48
+ model_names: list[str]
49
+ responses: list[str] = ["EMPTY", "EMPTY"]
50
+ energy_consumptions: list[float] = [0.0, 0.0]
51
+ response_victory_index: Optional[Literal[0, 1]] = None
52
+ extra_energy_was_worth: Optional[bool] = None
53
+
54
+ # The time when the user's stage changed.
55
+ timestamp: datetime = Field(default_factory=now)
56
+ # The user's current stage.
57
+ user_stage: UserStage = "prompted"
58
+ # When the user is not going through the aforementioned stages,
59
+ # the user's stage transition is recorded here.
60
+ abnormal_stage_change: list[tuple[UserStage, UserStage]] = []
61
+
62
+ def set_response_and_energy(self, model_index: Literal[0, 1], response: str, energy_consumption: float) -> None:
63
+ self.timestamp = now()
64
+ self.energy_consumptions[model_index] = energy_consumption
65
+ self.responses[model_index] = response
66
+
67
+ def set_response_vote(self, victory_index: Literal[0, 1]) -> None:
68
+ self.timestamp = now()
69
+
70
+ # Next stage depends on the user's vote.
71
+ energy_a, energy_b = self.energy_consumptions
72
+ if (victory_index == 0 and energy_a <= energy_b) or (victory_index == 1 and energy_a >= energy_b):
73
+ next_stage = "chose_less_energy_response"
74
+ else:
75
+ next_stage = "chose_more_energy_response"
76
+
77
+ # Detect abnormal stage change.
78
+ if self.user_stage != "prompted":
79
+ self.abnormal_stage_change.append((self.user_stage, next_stage))
80
+
81
+ self.user_stage = next_stage
82
+ self.response_victory_index = victory_index
83
+
84
+ def set_energy_vote(self, is_worth: bool) -> None:
85
+ self.timestamp = now()
86
+
87
+ # Detect abnormal stage change.
88
+ if self.user_stage != "chose_more_energy_response":
89
+ self.abnormal_stage_change.append((self.user_stage, "voted_energy"))
90
+
91
+ self.user_stage = "voted_energy"
92
+ self.extra_energy_was_worth = is_worth
93
+
94
+
95
+ class GenerationConfig(BaseModel):
96
+ """Configuration for generation of prompts."""
97
+ max_new_tokens: int
98
+ do_sample: bool
99
+ temperature: float
100
+ repetition_penalty: float
101
+ top_k: int
102
+ top_p: float
103
+
104
+
105
+ class Controller:
106
+ def __init__(
107
+ self,
108
+ background_task_interval: int,
109
+ max_num_req_states: int,
110
+ req_state_expiration_time: int,
111
+ worker_service: WorkerService,
112
+ generation_config: GenerationConfig,
113
+ ):
114
+ self.request_states: BoundedExpiringDict[str, RequestState] = \
115
+ BoundedExpiringDict(max_num_req_states, req_state_expiration_time)
116
+ self.worker_service = worker_service
117
+
118
+ self.generation_config = generation_config
119
+
120
+ self.background_task_handle = create_task(
121
+ self._background_task(background_task_interval),
122
+ )
123
+
124
+ def shutdown(self) -> None:
125
+ """Shutdown the controller."""
126
+ self.background_task_handle.cancel()
127
+
128
+ async def _background_task(self, heartbeat_interval: int) -> None:
129
+ """Periodically check if dead workers are alive again and do request state GC."""
130
+ while True:
131
+ await asyncio.sleep(heartbeat_interval)
132
+
133
+ await self.worker_service.check_workers()
134
+
135
+ prev_num_req_states = len(self.request_states)
136
+ self.request_states.cleanup()
137
+ controller_logger.info(
138
+ "Request state garbage collection done: Removed %d reqeusts",
139
+ prev_num_req_states - len(self.request_states),
140
+ )
141
+
142
+ def response_vote(self, request_id: str, victory_index: Literal[0, 1]) -> RequestState | None:
143
+ """Record the user's response vote and return the new state."""
144
+ if (state := self.request_states.get(request_id)) is not None:
145
+ state.set_response_vote(victory_index)
146
+ # Pop the state from the dict if the user has voted on energy.
147
+ if state.user_stage == "chose_less_energy_response":
148
+ self.request_states.pop(request_id)
149
+ request_logger.info(state.json())
150
+ return state
151
+ return None
152
+
153
+ def energy_vote(self, request_id: str, is_worth: bool) -> RequestState | None:
154
+ """Record the user's energy vote and return the new state."""
155
+ # Pop the state from the dict, since this is the last step in any case.
156
+ if (state := self.request_states.pop(request_id)) is not None:
157
+ state.set_energy_vote(is_worth)
158
+ request_logger.info(state.json())
159
+ return state
160
+ return None
161
+
162
+ async def prompt(
163
+ self,
164
+ request_id: str,
165
+ prompt: str,
166
+ model_index: Literal[0, 1],
167
+ ) -> AsyncGenerator[bytes, None]:
168
+ # This method is called twice for the same request, once for each model.
169
+ # If it's the first time this method is called, assign models to the request.
170
+ if request_id not in self.request_states:
171
+ workers = self.worker_service.choose_two()
172
+ model_names = [worker.model_name for worker in workers]
173
+ self.request_states[request_id] = RequestState(
174
+ request_id=request_id,
175
+ prompt=prompt,
176
+ model_names=model_names,
177
+ )
178
+ request_state = self.request_states[request_id]
179
+ model_name = request_state.model_names[model_index]
180
+ try:
181
+ worker = self.worker_service.get_worker(model_name)
182
+ except KeyError:
183
+ controller_logger.error("Worker %s not found.", model_name)
184
+ raise
185
+ except RuntimeError:
186
+ controller_logger.error("Worker %s is dead.", model_name)
187
+ raise
188
+ prompt, stop_str, stop_token_ids = apply_model_characteristics(
189
+ system_prompt=get_system_prompt("chat"),
190
+ prompt=prompt,
191
+ model_name=worker.model_id,
192
+ )
193
+
194
+ # Request the model worker to stream the response to the user's prompt.
195
+ response = ""
196
+ energy = 0.0
197
+ client = worker.get_client()
198
+ buffer = TokenGenerationBuffer(stop_str=stop_str)
199
+ try:
200
+ async for resp in client.generate_stream(
201
+ prompt=prompt,
202
+ stop_sequences=[stop_str] if stop_str is not None else None,
203
+ **self.generation_config.dict(),
204
+ ):
205
+ # Even special tokens consume energy when they're generated.
206
+ energy += resp.token.energy
207
+
208
+ # Stop tokens usually don't overlap with (human-readable) stop sequences.
209
+ # if resp.token.special or resp.token.id in stop_token_ids:
210
+ if resp.token.id in stop_token_ids:
211
+ # If the buffer is not empty (i.e., we had partial stop_str matches),
212
+ # just yield it to the user.
213
+ if (chunk := buffer.token_buffer):
214
+ response += chunk
215
+ yield json.dumps(chunk).encode() + b"\0"
216
+ break
217
+
218
+ # Skip special tokens.
219
+ if resp.token.special:
220
+ continue
221
+
222
+ # The buffer automatically handles `stop_str` partial and full matches.
223
+ buffer.append(resp.token.text)
224
+ if (chunk := buffer.pop()) is not None:
225
+ response += chunk
226
+ yield json.dumps(chunk).encode() + b"\0"
227
+ elif buffer.matched_stop_str:
228
+ break
229
+ except aiohttp.ClientConnectorError:
230
+ worker.status = "down"
231
+ controller_logger.error(
232
+ "Problem talking to %s. Aborting and setting worker status to down",
233
+ repr(worker),
234
+ )
235
+ raise
236
+ except Exception:
237
+ yield json.dumps(buffer.token_buffer).encode() + b"\0"
238
+ raise
239
+ finally:
240
+ request_state.set_response_and_energy(model_index, response, energy)
241
+ request_logger.info(request_state.json())
242
+
243
+
244
+ CONTROLLER: Controller | None = None
245
+
246
+ def init_global_controller(config: ControllerConfig) -> None:
247
+ global CONTROLLER
248
+ CONTROLLER = Controller(
249
+ background_task_interval=config.background_task_interval,
250
+ max_num_req_states=config.max_num_req_states,
251
+ req_state_expiration_time=config.req_state_expiration_time,
252
+ worker_service=WorkerService(config.compose_files),
253
+ generation_config=GenerationConfig(
254
+ max_new_tokens=config.max_new_tokens,
255
+ do_sample=config.do_sample,
256
+ temperature=config.temperature,
257
+ repetition_penalty=config.repetition_penalty,
258
+ top_k=config.top_k,
259
+ top_p=config.top_p,
260
+ ),
261
+ )
262
+
263
+ def get_global_controller() -> Controller:
264
+ global CONTROLLER
265
+ assert CONTROLLER is not None
266
+ return CONTROLLER
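To make the "Internal states" comment above concrete, here is a small sketch (not part of the commit) of the normal stage transitions of a `RequestState`, assuming the controller's dependencies are installed:

```python
# Sketch of the normal RequestState lifecycle described in the comments above.
from spitfight.colosseum.controller.controller import RequestState

state = RequestState(request_id="demo", prompt="Hi!", model_names=["model-a", "model-b"])
state.set_response_and_energy(0, "Hello from A", 10.0)
state.set_response_and_energy(1, "Hello from B", 25.0)

# The user preferred model B, which consumed *more* energy ...
state.set_response_vote(victory_index=1)
assert state.user_stage == "chose_more_energy_response"

# ... so they are additionally asked whether the extra energy was worth it.
state.set_energy_vote(is_worth=True)
assert state.user_stage == "voted_energy" and state.extra_energy_was_worth
```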
spitfight/colosseum/controller/router.py ADDED
@@ -0,0 +1,125 @@
1
+ import os
2
+ import json
3
+
4
+ import uvicorn
5
+ from pydantic import BaseSettings
6
+ from fastapi import FastAPI, Depends
7
+ from fastapi.responses import StreamingResponse
8
+ from fastapi.exceptions import HTTPException
9
+ from text_generation.errors import OverloadedError, UnknownError, ValidationError
10
+
11
+ from spitfight.log import get_logger, init_queued_root_logger, shutdown_queued_root_loggers
12
+ from spitfight.colosseum.common import (
13
+ COLOSSEUM_PROMPT_ROUTE,
14
+ COLOSSEUM_RESP_VOTE_ROUTE,
15
+ COLOSSEUM_ENERGY_VOTE_ROUTE,
16
+ COLOSSEUM_HEALTH_ROUTE,
17
+ PromptRequest,
18
+ ResponseVoteRequest,
19
+ ResponseVoteResponse,
20
+ EnergyVoteRequest,
21
+ EnergyVoteResponse,
22
+ )
23
+ from spitfight.colosseum.controller.controller import (
24
+ Controller,
25
+ init_global_controller,
26
+ get_global_controller,
27
+ )
28
+ from spitfight.utils import prepend_generator
29
+
30
+
31
+ class ControllerConfig(BaseSettings):
32
+ """Controller settings automatically loaded from environment variables."""
33
+ # Controller
34
+ background_task_interval: int = 300
35
+ max_num_req_states: int = 10000
36
+ req_state_expiration_time: int = 600
37
+ compose_files: list[str] = ["deployment/docker-compose-0.yaml", "deployment/docker-compose-1.yaml"]
38
+
39
+ # Logging
40
+ log_dir: str = "/logs"
41
+ controller_log_file: str = "controller.log"
42
+ request_log_file: str = "requests.log"
43
+ uvicorn_log_file: str = "uvicorn.log"
44
+
45
+ # Generation
46
+ max_new_tokens: int = 512
47
+ do_sample: bool = True
48
+ temperature: float = 1.0
49
+ repetition_penalty: float = 1.0
50
+ top_k: int = 50
51
+ top_p: float = 0.95
52
+
53
+
54
+ app = FastAPI()
55
+ settings = ControllerConfig()
56
+ logger = get_logger("spitfight.colosseum.controller.router")
57
+
58
+ @app.on_event("startup")
59
+ async def startup_event():
60
+ init_queued_root_logger("uvicorn", os.path.join(settings.log_dir, settings.uvicorn_log_file))
61
+ init_queued_root_logger("spitfight.colosseum.controller", os.path.join(settings.log_dir, settings.controller_log_file))
62
+ init_queued_root_logger("colosseum_requests", os.path.join(settings.log_dir, settings.request_log_file))
63
+ init_global_controller(settings)
64
+
65
+ @app.on_event("shutdown")
66
+ async def shutdown_event():
67
+ get_global_controller().shutdown()
68
+ shutdown_queued_root_loggers()
69
+
70
+ @app.post(COLOSSEUM_PROMPT_ROUTE)
71
+ async def prompt(
72
+ request: PromptRequest,
73
+ controller: Controller = Depends(get_global_controller),
74
+ ):
75
+ generator = controller.prompt(request.request_id, request.prompt, request.model_index)
76
+
77
+ # First try to get the first token in order to catch TGI errors.
78
+ try:
79
+ first_token = await generator.__anext__()
80
+ except OverloadedError:
81
+ name = controller.request_states[request.request_id].model_names[request.model_index]
82
+ logger.warning("Model %s is overloaded. Failed request: %s", name, repr(request))
83
+ raise HTTPException(status_code=429, detail="Model overloaded. Please try again later.")
84
+ except ValidationError as e:
85
+ logger.info("TGI returned validation error: %s. Failed request: %s", str(e), repr(request))
86
+ raise HTTPException(status_code=422, detail=str(e))
87
+ except StopAsyncIteration:
88
+ logger.info("TGI returned empty response. Failed request: %s", repr(request))
89
+ return StreamingResponse(
90
+ iter([json.dumps("*The model generated an empty response.*").encode() + b"\0"]),
91
+ )
92
+ except UnknownError as e:
93
+ logger.error("TGI returned unknown error: %s. Failed request: %s", str(e), repr(request))
94
+ raise HTTPException(status_code=500, detail=str(e))
95
+
96
+ return StreamingResponse(prepend_generator(first_token, generator))
97
+
98
+ @app.post(COLOSSEUM_RESP_VOTE_ROUTE, response_model=ResponseVoteResponse)
99
+ async def response_vote(
100
+ request: ResponseVoteRequest,
101
+ controller: Controller = Depends(get_global_controller),
102
+ ):
103
+ if (state := controller.response_vote(request.request_id, request.victory_index)) is None:
104
+ raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")
105
+ return ResponseVoteResponse(
106
+ energy_consumptions=state.energy_consumptions,
107
+ model_names=state.model_names,
108
+ )
109
+
110
+ @app.post(COLOSSEUM_ENERGY_VOTE_ROUTE, response_model=EnergyVoteResponse)
111
+ async def energy_vote(
112
+ request: EnergyVoteRequest,
113
+ controller: Controller = Depends(get_global_controller),
114
+ ):
115
+ if (state := controller.energy_vote(request.request_id, request.is_worth)) is None:
116
+ raise HTTPException(status_code=410, detail="Colosseum battle session timeout expired.")
117
+ return EnergyVoteResponse(model_names=state.model_names)
118
+
119
+ @app.get(COLOSSEUM_HEALTH_ROUTE)
120
+ async def health():
121
+ return "OK"
122
+
123
+
124
+ if __name__ == "__main__":
125
+ uvicorn.run(app, host="0.0.0.0", log_config=None)
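`ControllerConfig` is a pydantic `BaseSettings`, so every field can be overridden through an environment variable of the same (case-insensitive) name; this is how `-e LOG_DIR=/logs` in `deployment/controller-container.sh` takes effect. A hedged sketch with hypothetical values, assuming the `colosseum-controller` extra is installed:

```python
# Sketch: overriding controller settings via environment variables.
# List-typed fields are parsed from JSON by pydantic's BaseSettings.
import os

os.environ["LOG_DIR"] = "/tmp/colosseum-logs"
os.environ["MAX_NEW_TOKENS"] = "256"
os.environ["COMPOSE_FILES"] = '["deployment/docker-compose-0.yaml"]'

from spitfight.colosseum.controller.router import ControllerConfig

config = ControllerConfig()
assert config.log_dir == "/tmp/colosseum-logs"
assert config.max_new_tokens == 256
assert config.compose_files == ["deployment/docker-compose-0.yaml"]
```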
spitfight/colosseum/controller/worker.py ADDED
@@ -0,0 +1,151 @@
1
+ import yaml
2
+ import random
3
+ import asyncio
4
+ from typing import Literal
5
+ from functools import cached_property
6
+
7
+ import httpx
8
+ from pydantic import BaseModel
9
+ from text_generation import AsyncClient
10
+
11
+ from spitfight.log import get_logger
12
+
13
+ logger = get_logger(__name__)
14
+
15
+
16
+ class Worker(BaseModel):
17
+ """A worker that serves a model."""
18
+ # Worker's container name, since we're using Overlay networks.
19
+ hostname: str
20
+ # For TGI, this would always be 80.
21
+ port: int
22
+ # User-friendly model name, e.g. "metaai/llama2-13b-chat".
23
+ model_name: str
24
+ # Hugging Face model ID, e.g. "metaai/Llama-2-13b-chat-hf".
25
+ model_id: str
26
+ # Whether the model worker container is good.
27
+ status: Literal["up", "down"]
28
+
29
+ class Config:
30
+ keep_untouched = (cached_property,)
31
+
32
+ @cached_property
33
+ def url(self) -> str:
34
+ return f"http://{self.hostname}:{self.port}"
35
+
36
+ def get_client(self) -> AsyncClient:
37
+ return AsyncClient(base_url=self.url)
38
+
39
+ def audit(self) -> None:
40
+ """Make sure the worker is running and information is as expected.
41
+
42
+ Assumed to be called on app startup when workers are initialized.
43
+ This method will just raise `ValueError`s if audit fails in order to
44
+ prevent the controller from starting if anything is wrong.
45
+ """
46
+ try:
47
+ response = httpx.get(self.url + "/info")
48
+ except (httpx.ConnectError, httpx.TimeoutException) as e:
49
+ raise ValueError(f"Could not connect to {self!r}: {e!r}")
50
+ if response.status_code != 200:
51
+ raise ValueError(f"Could not get /info from {self!r}.")
52
+ info = response.json()
53
+ if info["model_id"] != self.model_id:
54
+ raise ValueError(f"Model name mismatch: {info['model_id']} != {self.model_id}")
55
+ self.status = "up"
56
+ logger.info("%s is up.", repr(self))
57
+
58
+ async def check_status(self) -> None:
59
+ """Check worker status and update `self.status` accordingly."""
60
+ async with httpx.AsyncClient() as client:
61
+ try:
62
+ response = await client.get(self.url + "/info")
63
+ except (httpx.ConnectError, httpx.TimeoutException) as e:
64
+ self.status = "down"
65
+ logger.warning("%s is down: %s", repr(self), repr(e))
66
+ return
67
+ if response.status_code != 200:
68
+ self.status = "down"
69
+ logger.warning("GET /info from %s returned %s.", repr(self), response.json())
70
+ return
71
+ info = response.json()
72
+ if info["model_id"] != self.model_id:
73
+ self.status = "down"
74
+ logger.warning(
75
+ "Model name mismatch for %s: %s != %s",
76
+ repr(self),
77
+ info["model_id"],
78
+ self.model_id,
79
+ )
80
+ return
81
+ logger.info("%s is up.", repr(self))
82
+ self.status = "up"
83
+
84
+
85
+ class WorkerService:
86
+ """A service that manages model serving workers.
87
+
88
+ Worker objects are only created once and shared across the
89
+ entire application. In particular, changing the status of a worker
90
+ will immediately take effect on the result of `choose_two`.
91
+
92
+ Attributes:
93
+ workers (list[Worker]): The list of workers.
94
+ """
95
+
96
+ def __init__(self, compose_files: list[str]) -> None:
97
+ """Initialize the worker service."""
98
+ self.workers: list[Worker] = []
99
+ worker_model_names = set()
100
+ for compose_file in compose_files:
101
+ spec = yaml.safe_load(open(compose_file))
102
+ for model_name, service_spec in spec["services"].items():
103
+ command = service_spec["command"]
104
+ for i, cmd in enumerate(command):
105
+ if cmd == "--model-id":
106
+ model_id = command[i + 1]
107
+ break
108
+ else:
109
+ raise ValueError(f"Could not find model ID in {command!r}")
110
+ worker_model_names.add(model_name)
111
+ worker = Worker(
112
+ hostname=service_spec["container_name"],
113
+ port=80,
114
+ model_name=model_name,
115
+ model_id=model_id,
116
+ status="down",
117
+ )
118
+ worker.audit()
119
+ self.workers.append(worker)
120
+
121
+ if len(worker_model_names) != len(self.workers):
122
+ raise ValueError("Model names must be unique.")
123
+
124
+ def get_worker(self, model_name: str) -> Worker:
125
+ """Get a worker by model name."""
126
+ for worker in self.workers:
127
+ if worker.model_name == model_name:
128
+ if worker.status == "down":
129
+ # This is an unfortunate case where, when the two models were chosen,
130
+ # the worker was up but then went down before the request
131
+ # completed. We'll just raise a 500 internal error and have the user
132
+ # try again. This won't be common.
133
+ raise RuntimeError(f"The worker with model name {model_name} is down.")
134
+ return worker
135
+ raise ValueError(f"Worker with model name {model_name} does not exist.")
136
+
137
+ def choose_two(self) -> tuple[Worker, Worker]:
138
+ """Choose two different workers.
139
+
140
+ This is a good place to apply the Strategy pattern if we later want to
141
+ implement different strategies for choosing workers.
142
+ """
143
+ live_workers = [worker for worker in self.workers if worker.status == "up"]
144
+ if len(live_workers) < 2:
145
+ raise ValueError("Not enough live workers to choose from.")
146
+ worker_a, worker_b = random.sample(live_workers, 2)
147
+ return worker_a, worker_b
148
+
149
+ async def check_workers(self) -> None:
150
+ """Check the status of all workers."""
151
+ await asyncio.gather(*[worker.check_status() for worker in self.workers])
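For context, here is a rough sketch of how the controller might drive `WorkerService`: build it from Compose files (which audits every worker), keep worker statuses fresh in the background, and sample two live workers per battle. The Compose file path and the 30-second polling interval are assumptions, not values from this repository.

```python
# Illustrative sketch of wiring WorkerService into an asyncio app.
import asyncio

from spitfight.colosseum.controller.worker import WorkerService


async def main() -> None:
    # The constructor audits every worker, so the TGI containers must already be reachable.
    service = WorkerService(compose_files=["deployment/tgi-compose.yaml"])  # assumed path

    async def watch() -> None:
        # Keep each worker's "up"/"down" status fresh so that choose_two()
        # only ever samples from live workers.
        while True:
            await service.check_workers()
            await asyncio.sleep(30)  # assumed interval

    watcher = asyncio.create_task(watch())

    # Pick two distinct live workers for one Colosseum battle.
    worker_a, worker_b = service.choose_two()
    print(f"{worker_a.model_name} vs {worker_b.model_name}")

    watcher.cancel()


if __name__ == "__main__":
    asyncio.run(main())
```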
spitfight/log.py ADDED
@@ -0,0 +1,76 @@
1
+ from __future__ import annotations
2
+
3
+ import queue
4
+ import logging
5
+ from logging.handlers import QueueHandler, QueueListener
6
+
7
+ ROOT_LOGGER_NAMES: list[str | None] = []
8
+ ROOT_LOGGER_QUEUE_LISTENERS: list[QueueListener] = []
9
+
10
+
11
+ def init_queued_root_logger(
12
+ name: str | None,
13
+ filepath: str,
14
+ level: int = logging.INFO,
15
+ ) -> None:
16
+ """Initialize a queue-based pseudo-root logger.
17
+
18
+ The pseudo-root logger will aggregate log messages from child
19
+ loggers under its namespace and send them to a queue. A QueueListener,
20
+ running in a separate thread, will then process the messages in the
21
+ queue and send them to the configured handlers.
22
+ """
23
+ global ROOT_LOGGER_NAMES, ROOT_LOGGER_QUEUE_LISTENERS
24
+
25
+ # Make this function idempotent.
26
+ if name in ROOT_LOGGER_NAMES:
27
+ return
28
+
29
+ logger = logging.getLogger(name)
30
+ logger.setLevel(level)
31
+ logger.propagate = False
32
+
33
+ shared_queue = queue.SimpleQueue()
34
+ queue_handler = QueueHandler(shared_queue)
35
+ logger.addHandler(queue_handler)
36
+
37
+ formatter = logging.Formatter(
38
+ "[%(asctime)s] [%(levelname)s] [%(name)s](%(filename)s:%(lineno)d) %(message)s"
39
+ )
40
+
41
+ stderr_handler = logging.StreamHandler()
42
+ stderr_handler.setLevel(level)
43
+ stderr_handler.setFormatter(formatter)
44
+
45
+ file_handler = logging.FileHandler(filepath, encoding="utf-8")
46
+ file_handler.setLevel(level)
47
+ file_handler.setFormatter(formatter)
48
+
49
+ queue_listener = QueueListener(shared_queue, file_handler, stderr_handler)
50
+ queue_listener.start()
51
+
52
+ ROOT_LOGGER_NAMES.append(name)
53
+ ROOT_LOGGER_QUEUE_LISTENERS.append(queue_listener)
54
+
55
+
56
+ def shutdown_queued_root_loggers() -> None:
57
+ """Shutdown all queue-based pseudo-root loggers.
58
+
59
+ This is necessary to make sure all log messages are flushed
60
+ before the application exits.
61
+ """
62
+ for queue_listener in ROOT_LOGGER_QUEUE_LISTENERS:
63
+ queue_listener.stop()
64
+
65
+
66
+ def get_logger(name: str, level: int = logging.INFO) -> logging.Logger:
67
+ """Setup a logger with the given name and level."""
68
+ # Don't reconfigure existing loggers.
69
+ if name in logging.Logger.manager.loggerDict:
70
+ return logging.getLogger(name)
71
+
72
+ logger = logging.getLogger(name)
73
+ logger.setLevel(level)
74
+ logger.propagate = True
75
+
76
+ return logger
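A quick usage sketch of these helpers: initialize one queued pseudo-root logger per namespace at startup, grab child loggers anywhere with `get_logger`, and stop the listeners on shutdown. The namespace and log file path below are illustrative assumptions.

```python
# Illustrative lifecycle of the queue-based logging helpers.
import logging

from spitfight.log import (
    get_logger,
    init_queued_root_logger,
    shutdown_queued_root_loggers,
)

# Aggregate everything under the "spitfight" namespace into one queue,
# drained by a background thread that writes to stderr and a file.
init_queued_root_logger("spitfight", "/tmp/spitfight.log", level=logging.INFO)  # assumed path

logger = get_logger("spitfight.example")
logger.info("This record propagates to the queued pseudo-root logger.")

# Flush and stop the QueueListener threads before exiting.
shutdown_queued_root_loggers()
```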
spitfight/prompt.py ADDED
@@ -0,0 +1,69 @@
1
+ """An abstraction layer for prompting different models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+
7
+ from fastchat.model.model_adapter import get_conversation_template
8
+
9
+
10
+ class Task(enum.Enum):
11
+ """Different system prompt styles."""
12
+
13
+ CHAT = "chat"
14
+ CHAT_CONCISE = "chat-concise"
15
+ INSTRUCT = "instruct"
16
+ INSTRUCT_CONCISE = "instruct-concise"
17
+
18
+
19
+ SYSTEM_PROMPTS = {
20
+ Task.CHAT: (
21
+ "A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
22
+ "The assistant gives helpful, detailed, and polite answers to the user's questions. "
23
+ ),
24
+ Task.CHAT_CONCISE: (
25
+ "A chat between a human user (prompter) and an artificial intelligence (AI) assistant. "
26
+ "The assistant gives helpful, detailed, and polite answers to the user's questions. "
27
+ "The assistant's answers are very concise. "
28
+ ),
29
+ Task.INSTRUCT: (
30
+ "Below is an instruction that describes a task. "
31
+ "Write a response that appropriately completes the request. "
32
+ ),
33
+ Task.INSTRUCT_CONCISE: (
34
+ "Below is an instruction that describes a task. "
35
+ "Write a response that appropriately completes the request. "
36
+ "The response should be very concise. "
37
+ ),
38
+ }
39
+
40
+ def get_system_prompt(task: Task | str) -> str:
41
+ """Get the system prompt for a given task."""
42
+ if isinstance(task, str):
43
+ task = Task(task)
44
+ return SYSTEM_PROMPTS[task]
45
+
46
+
47
+ def apply_model_characteristics(
48
+ system_prompt: str,
49
+ prompt: str,
50
+ model_name: str,
51
+ ) -> tuple[str, str | None, list[int]]:
52
+ """Apply and return model-specific differences."""
53
+ conv = get_conversation_template(model_name)
54
+
55
+ if "llama-2" in model_name.lower():
56
+ conv.system = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n"
57
+ elif "stablelm" in model_name.lower():
58
+ conv.system = f"""<|SYSTEM|># {system_prompt}\n"""
59
+ else:
60
+ conv.system = system_prompt
61
+ conv.messages = []
62
+ conv.offset = 0
63
+
64
+ conv.append_message(conv.roles[0], prompt)
65
+ conv.append_message(conv.roles[1], "")
66
+
67
+ stop_str = None if conv.stop_str is None or not conv.stop_str else conv.stop_str
68
+
69
+ return conv.get_prompt(), stop_str, (conv.stop_token_ids or [])
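As a usage sketch, building a model-specific prompt for one battle turn looks roughly like this. The model name is an assumed example; any name that fastchat's `get_conversation_template` recognizes works the same way.

```python
# Illustrative use of the prompt helpers above.
from spitfight.prompt import Task, apply_model_characteristics, get_system_prompt

prompt, stop_str, stop_token_ids = apply_model_characteristics(
    system_prompt=get_system_prompt(Task.CHAT_CONCISE),
    prompt="Explain what a kilowatt-hour is in one sentence.",
    model_name="lmsys/vicuna-13b-v1.3",  # assumed example model
)

# `prompt` is what gets sent to the TGI worker; `stop_str` and
# `stop_token_ids` tell the streaming client when to stop generation.
print(prompt)
print(stop_str, stop_token_ids)
```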
spitfight/utils.py ADDED
@@ -0,0 +1,305 @@
1
+ from __future__ import annotations
2
+
3
+ import time
4
+ import heapq
5
+ import asyncio
6
+ import unittest
7
+ from typing import TypeVar, Generic, AsyncGenerator, Any, Coroutine
8
+
9
+ from fastapi.logger import logger
10
+
11
+ K = TypeVar('K')
12
+ V = TypeVar('V')
13
+
14
+
15
+ class BoundedExpiringDict(Generic[K, V]):
16
+ def __init__(self, max_size: int, expiration_time: int) -> None:
17
+ self.data_dict: dict[K, V] = {}
18
+ self.timestamp_heap: list[tuple[float, K]] = []
19
+ self.timeout = expiration_time
20
+
21
+ # Without this, the controller is vulnerable to "user flood attacks,"
22
+ # where someone can create a bunch of users by polling /request before
23
+ # self.timeout expires and blow up memory.
24
+ self.max_size = max_size
25
+
26
+ def __getitem__(self, key: K) -> V:
27
+ return self.data_dict[key]
28
+
29
+ def __setitem__(self, key: K, value: V) -> None:
30
+ if len(self.data_dict) >= self.max_size:
31
+ self.cleanup()
32
+
33
+ heapq.heappush(self.timestamp_heap, (time.monotonic(), key))
34
+ self.data_dict[key] = value
35
+
36
+ def __delitem__(self, key: K) -> None:
37
+ # This is a bit inefficient, but it's not a common case operation.
38
+ # We still need to do this to keep timestamp_heap in sync.
39
+ del self.data_dict[key]
40
+ for i, (_, existing_key) in enumerate(self.timestamp_heap):
41
+ if existing_key == key:
42
+ del self.timestamp_heap[i]
43
+ break
44
+ heapq.heapify(self.timestamp_heap)
45
+
46
+ def __contains__(self, key: K) -> bool:
47
+ return key in self.data_dict
48
+
49
+ def __len__(self) -> int:
50
+ return len(self.data_dict)
51
+
52
+ def get(self, key: K, default: V | None = None) -> V | None:
53
+ return self.data_dict.get(key, default)
54
+
55
+ def pop(self, key: K, default: V | None = None) -> V | None:
56
+ item = self.data_dict.pop(key, default)
57
+ if item is not None:
58
+ for i, (_, existing_key) in enumerate(self.timestamp_heap):
59
+ if existing_key == key:
60
+ del self.timestamp_heap[i]
61
+ break
62
+ heapq.heapify(self.timestamp_heap)
63
+ return item
64
+
65
+ def cleanup(self) -> None:
66
+ now = time.monotonic()
67
+ # After the while loop, the dictionary holds at most max_size entries,
68
+ # and every remaining key was inserted within the timeout.
69
+ while (self.timestamp_heap and now - self.timestamp_heap[0][0] > self.timeout) or len(self.data_dict) > self.max_size:
70
+ _, key = heapq.heappop(self.timestamp_heap)
71
+ del self.data_dict[key]
72
+
73
+ assert len(self.data_dict) == len(self.timestamp_heap)
74
+
75
+
76
+ T = TypeVar("T")
77
+
78
+
79
+ async def prepend_generator(
80
+ first_item: T,
81
+ generator: AsyncGenerator[T, None],
82
+ ) -> AsyncGenerator[T, None]:
83
+ """Prepend an item to an async generator."""
84
+ yield first_item
85
+ async for item in generator:
86
+ yield item
87
+
88
+
89
+ def create_task(coroutine: Coroutine[Any, Any, T]) -> asyncio.Task[T]:
90
+ """Create an `asyncio.Task` but ensure that exceptions are logged.
91
+
92
+ Reference: https://quantlane.com/blog/ensure-asyncio-task-exceptions-get-logged/
93
+ """
94
+ loop = asyncio.get_running_loop()
95
+ task = loop.create_task(coroutine)
96
+ task.add_done_callback(_handle_task_exception)
97
+ return task
98
+
99
+
100
+ def _handle_task_exception(task: asyncio.Task) -> None:
101
+ """Print out exception and tracebook when a task dies with an exception."""
102
+ try:
103
+ task.result()
104
+ except asyncio.CancelledError:
105
+ # Cancellation should not be logged as an error.
106
+ pass
107
+ except Exception: # pylint: disable=broad-except
108
+ # `logger.exception` automatically handles exception and traceback info.
109
+ logger.exception("Job task died with an exception!")
110
+
111
+
112
+ class TokenGenerationBuffer:
113
+ """A constant sized buffer for tokens, used to handle stop sequences.
114
+
115
+ Attributes:
116
+ token_buffer (str): Internal buffer for tokens.
117
+ matched_stop_str (bool): Whether the stop string has been seen. When this
118
+ is True, generation should stop and `pop` will always return None.
119
+ """
120
+ def __init__(self, stop_str: str | None = None) -> None:
121
+ """Initialize the buffer.
122
+
123
+ If `stop_str` is None, the buffer will just return all tokens as they come.
124
+ """
125
+ self.stop_str = stop_str
126
+ self.token_len_list = []
127
+ self.token_buffer = ""
128
+ self.matched_stop_str = False
129
+
130
+ def append(self, text: str) -> None:
131
+ """Append a token to the buffer."""
132
+ if self.stop_str is not None:
133
+ self.token_len_list.append(len(text))
134
+ self.token_buffer += text
135
+
136
+ def _pop_one(self) -> str:
137
+ """Remove and return the first token in the buffer."""
138
+ token_len = self.token_len_list.pop(0)
139
+ token, self.token_buffer = self.token_buffer[:token_len], self.token_buffer[token_len:]
140
+ return token
141
+
142
+ def pop(self) -> str | None:
143
+ """Try to pop a token from the buffer.
144
+
145
+ Return value None means that there is nothing to yield for now.
146
+ Repeated calls to this method will always just return None before more
147
+ tokens are appended to the buffer.
148
+ """
149
+ # A short circuit for no stop string.
150
+ if self.stop_str is None:
151
+ return_buffer = self.token_buffer or None
152
+ self.token_buffer = ""
153
+ return return_buffer
154
+
155
+ if self.matched_stop_str:
156
+ return None
157
+
158
+ # The token buffer matched the stop string. We're done generating.
159
+ if self.stop_str == self.token_buffer:
160
+ self.matched_stop_str = True
161
+ return None
162
+
163
+ # The tokens in the buffer could potentially be part of the stop string.
164
+ # We'll stay put until we see more tokens. This also covers the case of
165
+ # empty token buffer.
166
+ if self.stop_str.startswith(self.token_buffer):
167
+ return None
168
+
169
+ # We can return tokens from the beginning of the buffer until the buffer
170
+ # is a prefix of the stop string.
171
+ return_buffer = ""
172
+ while self.token_buffer:
173
+ return_buffer += self._pop_one()
174
+ if self.stop_str == self.token_buffer:
175
+ self.matched_stop_str = True
176
+ break
177
+ if self.stop_str.startswith(self.token_buffer):
178
+ break
179
+
180
+ return return_buffer or None
181
+
182
+
183
+
184
+ class TestTokenGenerationBuffer(unittest.TestCase):
185
+ def test_basic1(self):
186
+ buffer = TokenGenerationBuffer(stop_str="stop")
187
+
188
+ buffer.append("hello")
189
+ self.assertEqual(buffer.pop(), "hello")
190
+ self.assertEqual(buffer.pop(), None)
191
+ self.assertFalse(buffer.matched_stop_str)
192
+
193
+ buffer.append("world")
194
+ self.assertEqual(buffer.pop(), "world")
195
+ self.assertFalse(buffer.matched_stop_str)
196
+
197
+ buffer.append("stop")
198
+ self.assertEqual(buffer.pop(), None)
199
+ self.assertTrue(buffer.matched_stop_str)
200
+ self.assertEqual(buffer.pop(), None)
201
+ self.assertTrue(buffer.matched_stop_str)
202
+ self.assertEqual(buffer.pop(), None)
203
+ self.assertTrue(buffer.matched_stop_str)
204
+ self.assertEqual(buffer.pop(), None)
205
+ self.assertTrue(buffer.matched_stop_str)
206
+
207
+ def test_basic2(self):
208
+ buffer = TokenGenerationBuffer(stop_str="stop")
209
+
210
+ buffer.append("hi")
211
+ self.assertEqual(buffer.pop(), "hi")
212
+ self.assertFalse(buffer.matched_stop_str)
213
+
214
+ buffer.append("stole")
215
+ self.assertEqual(buffer.pop(), "stole")
216
+ self.assertFalse(buffer.matched_stop_str)
217
+
218
+ buffer.append("sto")
219
+ self.assertEqual(buffer.pop(), None)
220
+ self.assertFalse(buffer.matched_stop_str)
221
+
222
+ buffer.append("ic")
223
+ self.assertEqual(buffer.pop(), "stoic")
224
+ self.assertFalse(buffer.matched_stop_str)
225
+
226
+ buffer.append("st")
227
+ self.assertEqual(buffer.pop(), None)
228
+ self.assertFalse(buffer.matched_stop_str)
229
+
230
+ buffer.append("opper")
231
+ self.assertEqual(buffer.pop(), "stopper")
232
+ self.assertFalse(buffer.matched_stop_str)
233
+
234
+ buffer.append("sto")
235
+ self.assertEqual(buffer.pop(), None)
236
+ self.assertFalse(buffer.matched_stop_str)
237
+
238
+ buffer.append("p")
239
+ self.assertEqual(buffer.pop(), None)
240
+ self.assertTrue(buffer.matched_stop_str)
241
+
242
+ def test_falcon1(self):
243
+ buffer = TokenGenerationBuffer(stop_str="\nUser")
244
+
245
+ buffer.append("Hi")
246
+ self.assertEqual(buffer.pop(), "Hi")
247
+ self.assertFalse(buffer.matched_stop_str)
248
+
249
+ buffer.append("!")
250
+ self.assertEqual(buffer.pop(), "!")
251
+ self.assertFalse(buffer.matched_stop_str)
252
+
253
+ buffer.append("\n")
254
+ self.assertEqual(buffer.pop(), None)
255
+ self.assertFalse(buffer.matched_stop_str)
256
+
257
+ buffer.append("User")
258
+ self.assertEqual(buffer.pop(), None)
259
+ self.assertTrue(buffer.matched_stop_str)
260
+
261
+ def test_falcon2(self):
262
+ buffer = TokenGenerationBuffer(stop_str="\nUser")
263
+
264
+ buffer.append("\n")
265
+ self.assertEqual(buffer.pop(), None)
266
+ self.assertFalse(buffer.matched_stop_str)
267
+
268
+ buffer.append("\n")
269
+ self.assertEqual(buffer.pop(), "\n")
270
+ self.assertFalse(buffer.matched_stop_str)
271
+
272
+ buffer.append("\n")
273
+ self.assertEqual(buffer.pop(), "\n")
274
+ self.assertFalse(buffer.matched_stop_str)
275
+
276
+ buffer.append("\n")
277
+ self.assertEqual(buffer.pop(), "\n")
278
+ self.assertFalse(buffer.matched_stop_str)
279
+
280
+ buffer.append("User")
281
+ self.assertEqual(buffer.pop(), None)
282
+ self.assertEqual(buffer.pop(), None)
283
+ self.assertTrue(buffer.matched_stop_str)
284
+
285
+ def test_no_stop_str(self):
286
+ buffer = TokenGenerationBuffer(stop_str=None)
287
+
288
+ buffer.append("hello")
289
+ self.assertEqual(buffer.pop(), "hello")
290
+ self.assertEqual(buffer.pop(), None)
291
+ self.assertFalse(buffer.matched_stop_str)
292
+
293
+ buffer.append("world")
294
+ self.assertEqual(buffer.pop(), "world")
295
+ self.assertEqual(buffer.pop(), None)
296
+ self.assertFalse(buffer.matched_stop_str)
297
+
298
+ buffer.append("\n")
299
+ self.assertEqual(buffer.pop(), "\n")
300
+ self.assertEqual(buffer.pop(), None)
301
+ self.assertFalse(buffer.matched_stop_str)
302
+
303
+
304
+ if __name__ == "__main__":
305
+ unittest.main()
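To make the stop-string handling concrete, here is a small sketch of `TokenGenerationBuffer` filtering a streamed token sequence so that the stop string never reaches the user. The hard-coded token list stands in for tokens streamed back from a TGI worker.

```python
# Illustrative streaming loop around TokenGenerationBuffer.
from spitfight.utils import TokenGenerationBuffer

buffer = TokenGenerationBuffer(stop_str="\nUser")
streamed_tokens = ["Sure", "!", " Here", " you", " go", ".", "\n", "User"]  # stand-in stream

for token in streamed_tokens:
    buffer.append(token)
    # pop() returns None while the buffered text could still be a stop-string prefix.
    if (text := buffer.pop()) is not None:
        print(text, end="", flush=True)  # safe to forward to the user
    if buffer.matched_stop_str:
        break  # the model started a new "User" turn; stop generation here
```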