# Captured from a Hugging Face Space page; the Space's status at capture time was "Runtime error".
import asyncio
import os
import random
import sqlite3
from contextlib import closing

import pandas as pd
import panel as pn
from litellm import acompletion
# Candidate models; respond() draws one at random per chat at the start of a round.
MODELS = [
    "mistral/mistral-tiny",
    "mistral/mistral-small",
    "mistral/mistral-medium",
    "mistral/mistral-large-latest",
]

# Voting button labels. click_vote() dispatches on their positions:
#   [0] credits chat A, [1] credits both chats, [2] credits neither, [3] credits chat B.
# NOTE(review): the leading emoji look mis-encoded ("π", "π€") — confirm the
# intended characters before changing them; these strings are also the button names.
VOTING_LABELS = [
    "π A is better",
    "π€ About the same",
    "π Both not good",
    "π B is better",
]

# Load Panel's Perspective extension, used by the "Voting Results" tab.
pn.extension("perspective")
async def respond(content, user, instance):
    """
    Stream a reply from the model assigned to ``instance`` into that chat.

    If no model is recorded in ``chat_models`` for ``instance.name``, this is
    the first message of a new round. In that case:

    * the interface's history is trimmed to its newest message;
    * both headers are reset to their anonymous "## Model: A/B" labels;
    * a model is drawn at random from ``MODELS`` and remembered.

    The conversation is sent through litellm's ``acompletion`` with
    ``stream=True`` and ``max_tokens=128``, and tokens are streamed into the
    chat as "Assistant". The interface is disabled while this runs and
    re-enabled even if the request fails.

    Parameters
    ----------
    content : str
        The user's message text.
    user : str
        Sender name; unused here, kept for the ChatInterface callback signature.
    instance : pn.chat.ChatInterface
        The chat interface to answer in.
    """
    try:
        instance.disabled = True
        chat_label = instance.name
        model = chat_models.get(chat_label)
        if not model:
            # New round: drop earlier turns, keeping only the message just sent.
            instance.objects = instance.objects[-1:]
            header_a.object = "## Model: A"
            header_b.object = "## Model: B"
            model = chat_models[chat_label] = random.choice(MODELS)

        messages = instance.serialize()
        # The user's message is normally already in the serialized history:
        # forward_message appends it to the sibling interface before calling
        # respond, and the interface the user typed into presumably holds it
        # too (TODO confirm against Panel's ChatInterface). Appending it
        # unconditionally sent the prompt to the model twice.
        last = messages[-1] if messages else {}
        if (last.get("role"), last.get("content")) != ("user", content):
            messages.append({"role": "user", "content": content})

        # Prefer a key typed into the header field; otherwise use the environment.
        api_key = api_key_input.value or os.environ.get("MISTRAL_API_KEY")

        response = await acompletion(
            model=model, messages=messages, stream=True, max_tokens=128, api_key=api_key
        )
        message = None
        async for chunk in response:
            # Guard against stream chunks that carry no choices.
            if not chunk.choices:
                continue
            token = chunk.choices[0].delta["content"]
            if not token:
                continue
            # Passing the previous message back appends to it rather than
            # creating a new chat bubble per token.
            message = instance.stream(token, user="Assistant", message=message)
    finally:
        instance.disabled = False
async def forward_message(content, user, instance):
    """
    Mirror a user message into the sibling chat, then answer it in both chats.

    Shared callback of ``chat_interface_a`` and ``chat_interface_b``. The
    message is copied into whichever interface did not receive it. Then
    ``respond`` runs concurrently for A and B, so both models answer the same
    prompt side by side.
    """
    sibling = chat_interface_b if instance is chat_interface_a else chat_interface_a
    sibling.append(pn.chat.ChatMessage(content, user=user))
    await asyncio.gather(
        *(respond(content, user, target) for target in (chat_interface_a, chat_interface_b))
    )
def click_vote(event):
    """
    Record a vote for the current round, reveal both models, and persist totals.

    ``event.obj.name`` is the clicked button's label and is compared against
    ``VOTING_LABELS``:

    * ``VOTING_LABELS[0]`` credits the model behind chat A;
    * ``VOTING_LABELS[3]`` credits the model behind chat B;
    * ``VOTING_LABELS[1]`` credits both (only once if A and B drew the same model);
    * ``VOTING_LABELS[2]`` credits nothing.

    Every counted click then does four things:

    * writes the model names into both headers;
    * clears ``chat_models``, so the next message starts a new round;
    * refreshes the Perspective table;
    * overwrites the ``voting_counts`` table in ``voting_counts.db``.

    Clicks are ignored until both chats have been assigned a model.
    """
    label_a = chat_interface_a.name
    label_b = chat_interface_b.name
    # Both sides need a model; otherwise the lookups below would raise KeyError.
    if label_a not in chat_models or label_b not in chat_models:
        return
    model_a = chat_models[label_a]
    model_b = chat_models[label_b]

    voting_label = event.obj.name
    if voting_label == VOTING_LABELS[0]:
        credited = [model_a]
    elif voting_label == VOTING_LABELS[3]:
        credited = [model_b]
    elif voting_label == VOTING_LABELS[1]:
        # dict.fromkeys de-duplicates while keeping the A-then-B order.
        credited = list(dict.fromkeys((model_a, model_b)))
    else:
        credited = []
    for model in credited:
        voting_counts[model] = voting_counts.get(model, 0) + 1

    header_a.object = f"## Model: {model_a}"
    header_b.object = f"## Model: {model_b}"
    chat_models.clear()

    perspective.object = (
        pd.DataFrame(voting_counts, index=["Votes"])
        .melt(var_name="Model", value_name="Votes")
        .set_index("Model")
    )
    # A sqlite3 connection's own context manager only commits/rolls back;
    # closing() also releases the connection after each vote.
    with closing(sqlite3.connect("voting_counts.db")) as conn, conn:
        pd.DataFrame(list(voting_counts.items()), columns=["Model", "Votes"]).to_sql(
            "voting_counts", conn, if_exists="replace", index=False
        )
# initialize
# Chat interface name ("A"/"B") -> model drawn for the current round; emptied after each vote.
chat_models = {}

# Load persisted vote totals as {model: votes}, creating the table on first run.
# closing() releases the connection afterwards; the inner ``conn`` context
# commits (or rolls back) exactly as the bare ``with sqlite3.connect(...)`` did.
with closing(sqlite3.connect("voting_counts.db")) as conn, conn:
    conn.execute(
        "CREATE TABLE IF NOT EXISTS voting_counts (Model TEXT PRIMARY KEY, Votes INTEGER)"
    )
    voting_counts = (
        pd.read_sql("SELECT * FROM voting_counts", conn)
        .set_index("Model")["Votes"]
        .to_dict()
    )
# header
# Optional per-session key; respond() falls back to the MISTRAL_API_KEY
# environment variable when this field is left blank.
api_key_input = pn.widgets.PasswordInput(
    stylesheets=[".bk-input {color: black};"],
    placeholder="Mistral API Key",
)
# main
tabs = pn.Tabs()

# tab 1: two anonymous chats side by side, sharing one callback, plus voting buttons.
chat_interface_kwargs = {
    "callback": forward_message,
    "show_undo": False,
    "show_rerun": False,
    "show_clear": False,
    "show_stop": False,
    "show_button_name": False,
}

# The interface names "A"/"B" are the keys respond() and click_vote() use in chat_models.
header_a = pn.pane.Markdown("## Model: A")
chat_interface_a = pn.chat.ChatInterface(name="A", header=header_a, **chat_interface_kwargs)
header_b = pn.pane.Markdown("## Model: B")
chat_interface_b = pn.chat.ChatInterface(name="B", header=header_b, **chat_interface_kwargs)

# One full-width button per voting label, all routed to click_vote().
button_kwargs = {"sizing_mode": "stretch_width"}
button_row = pn.Row()
for voting_label in VOTING_LABELS:
    button = pn.widgets.Button(name=voting_label, **button_kwargs)
    button.on_click(click_vote)
    button_row.append(button)

chat_row = pn.Row(chat_interface_a, chat_interface_b)
tabs.append(("Chat", pn.Column(chat_row, button_row)))
# tab 2: running vote totals, one row per model, indexed by model name.
initial_votes = pd.DataFrame(voting_counts, index=["Votes"]).melt(
    var_name="Model", value_name="Votes"
)
perspective = pn.pane.Perspective(
    initial_votes.set_index("Model"),
    sizing_mode="stretch_both",
    editable=False,
)
tabs.append(("Voting Results", perspective))
# layout: API-key field in the header bar, the two tabs in the main area.
template = pn.template.FastListTemplate(
    title="Mistral Chat Arena",
    header=[api_key_input],
    main=[tabs],
)
template.servable()