Spaces:
Running
Running
""" | |
Datastore module for chat_arena.

This module provides utilities for storing messages and game results in a database.
Currently, only Supabase is supported.
""" | |
import json | |
import os | |
import uuid | |
from typing import List | |
from .arena import Arena | |
from .message import Message | |
# Supabase is an optional dependency: work out at import time whether it can
# be used, so callers can check `supabase_available` before using the database.

# Read the credentials outside the `try` so both names are always defined,
# even when the client library itself cannot be imported.
SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
SUPABASE_SECRET_KEY = os.environ.get("SUPABASE_SECRET_KEY", "")

try:
    import supabase
except Exception:
    # Deliberately broad: any failure while importing the optional client
    # (not only ImportError) simply disables database logging.
    supabase = None

# Explicit boolean instead of an `assert`: assertions are stripped under
# `python -O`, which would report Supabase as available without credentials.
supabase_available = bool(
    supabase is not None and SUPABASE_URL and SUPABASE_SECRET_KEY
)
class SupabaseDB:
    """Store arena configuration and messages in a Supabase database.

    Rows are written to the ``Arena``, ``Moderator``, ``Player`` and
    ``Message`` tables through the client created in ``__init__``.
    """

    def __init__(self) -> None:
        """Create the Supabase client from the module-level credentials.

        Raises:
            AssertionError: If Supabase was not detected as available or the
                URL / secret key are empty.
        """
        # Explicit raise rather than an `assert` statement so the check is not
        # stripped under `python -O`; AssertionError is kept so the exception
        # type seen by callers does not change. `supabase_available` is tested
        # first because the credential names may be undefined if the
        # module-level Supabase setup failed early.
        if not (supabase_available and SUPABASE_URL and SUPABASE_SECRET_KEY):
            raise AssertionError(
                "Supabase is not available: install the `supabase` package and "
                "set SUPABASE_URL and SUPABASE_SECRET_KEY."
            )
        self.client = supabase.create_client(SUPABASE_URL, SUPABASE_SECRET_KEY)

    def save_arena(self, arena: Arena) -> None:
        """Save the arena state: environment, player configs, then messages."""
        self._save_environment(arena)
        self._save_player_configs(arena)
        self.save_messages(arena)

    def _save_environment(self, arena: Arena) -> None:
        """Insert the arena's environment config and, if present, its moderator."""
        env_config = arena.environment.to_config()
        # The moderator gets its own table row, so remove it from the config
        # before the remainder is serialized into the Arena row.
        moderator_config = env_config.pop("moderator", None)
        arena_id = str(arena.uuid)

        arena_row = {
            "arena_id": arena_id,
            "global_prompt": arena.global_prompt,
            "env_type": env_config["env_type"],
            "env_config": json.dumps(env_config),
        }
        self.client.table("Arena").insert(arena_row).execute()

        if not moderator_config:
            return

        backend_config = moderator_config["backend"]
        moderator_row = {
            # ID derived from the arena UUID and the serialized moderator config.
            "moderator_id": str(
                uuid.uuid5(arena.uuid, json.dumps(moderator_config))
            ),
            "arena_id": arena_id,
            "role_desc": moderator_config["role_desc"],
            "terminal_condition": moderator_config["terminal_condition"],
            "backend_type": backend_config["backend_type"],
            # Backend configs may omit sampling parameters (player rows already
            # allow for this); store None instead of raising KeyError.
            "temperature": backend_config.get("temperature"),
            "max_tokens": backend_config.get("max_tokens"),
        }
        self.client.table("Moderator").insert(moderator_row).execute()

    def _save_player_configs(self, arena: Arena) -> None:
        """Insert one ``Player`` row per player of the arena."""
        arena_id = str(arena.uuid)
        player_rows = []
        for player in arena.players:
            player_config = player.to_config()
            backend_config = player_config["backend"]
            player_rows.append(
                {
                    # ID derived from the arena UUID and the serialized config.
                    "player_id": str(
                        uuid.uuid5(arena.uuid, json.dumps(player_config))
                    ),
                    "arena_id": arena_id,
                    "name": player.name,
                    "role_desc": player_config["role_desc"],
                    "backend_type": backend_config["backend_type"],
                    "temperature": backend_config.get("temperature"),
                    "max_tokens": backend_config.get("max_tokens"),
                }
            )

        if not player_rows:
            # No players: skip the request instead of sending an empty insert.
            return
        self.client.table("Player").insert(player_rows).execute()

    def save_messages(self, arena: Arena, messages: List[Message] = None) -> None:
        """Insert messages that are not yet logged and mark them as logged.

        Args:
            arena: Arena the messages belong to; its UUID is used to derive
                the message IDs.
            messages: Messages to store. When ``None``, the result of
                ``arena.environment.get_observation()`` is used.
        """
        if messages is None:
            messages = arena.environment.get_observation()

        # Only messages that have not been written before are sent.
        pending = [msg for msg in messages if not msg.logged]
        if not pending:
            # Nothing new to store: skip the request instead of sending an
            # empty insert.
            return

        arena_id = str(arena.uuid)
        message_rows = [
            {
                "message_id": str(uuid.uuid5(arena.uuid, message.msg_hash)),
                "arena_id": arena_id,
                "agent_name": message.agent_name,
                "content": message.content,
                "turn": message.turn,
                "timestamp": str(message.timestamp),
                "msg_type": message.msg_type,
                "visible_to": json.dumps(message.visible_to),
            }
            for message in pending
        ]
        self.client.table("Message").insert(message_rows).execute()

        # Flag the messages only after the insert call returned without raising.
        for message in pending:
            message.logged = True
def log_arena(arena: Arena, database=None) -> None:
    """Save the arena state through ``database``, if one is given.

    Args:
        arena: Arena whose state should be stored.
        database: Object exposing ``save_arena(arena)``. When ``None``,
            logging is disabled and the call does nothing.
    """
    if database is not None:
        database.save_arena(arena)
def log_messages(arena: Arena, messages: List[Message], database=None) -> None:
    """Save ``messages`` through ``database``, if one is given.

    Args:
        arena: Arena the messages belong to.
        messages: Messages to store; passed on unchanged.
        database: Object exposing ``save_messages(arena, messages)``. When
            ``None``, logging is disabled and the call does nothing.
    """
    if database is not None:
        database.save_messages(arena, messages)