AgentReview / agentreview /database.py
Yiqiao Jin
Initial Commit
bdafe83
raw
history blame
4.72 kB
"""
Datastore module for chat_arena.
This module provides utilities for storing messages and game results in a database.
Currently, it supports Supabase.
"""
import json
import os
import uuid
from typing import List
from .arena import Arena
from .message import Message
# Attempt importing Supabase. The integration is optional: any failure while
# importing (missing package or a broken install) simply disables it.
try:
    import supabase
except Exception:
    supabase = None

# Get the Supabase URL and secret key from environment variables. They are
# defined even when the import fails, so these module names always exist.
SUPABASE_URL = os.environ.get("SUPABASE_URL", "")
SUPABASE_SECRET_KEY = os.environ.get("SUPABASE_SECRET_KEY", "")

# True only when the package imported AND both credentials are non-empty.
# This is an explicit expression rather than an `assert`: asserts are removed
# under `python -O`, which would report Supabase as available with no
# credentials configured.
supabase_available = bool(
    supabase is not None and SUPABASE_URL and SUPABASE_SECRET_KEY
)
class SupabaseDB:
    """Store arena configuration and messages in a Supabase database.

    Rows are written to the ``Arena``, ``Moderator``, ``Player`` and ``Message``
    tables. Moderator, player and message IDs are ``uuid.uuid5`` values
    derived from the arena UUID plus the serialized config / message hash, so
    they are deterministic for the same arena and content.
    """

    def __init__(self):
        """Create a Supabase client from the module-level credentials.

        Raises:
            AssertionError: if the ``supabase`` package is unavailable or
                ``SUPABASE_URL`` / ``SUPABASE_SECRET_KEY`` is empty.
        """
        # Explicit raise instead of `assert`: an assert is stripped under
        # `python -O`, letting a missing package or empty credentials through.
        # AssertionError is kept so callers see the same exception type.
        if not (supabase_available and SUPABASE_URL and SUPABASE_SECRET_KEY):
            raise AssertionError(
                "Supabase is unavailable: install the `supabase` package and "
                "set SUPABASE_URL and SUPABASE_SECRET_KEY."
            )
        self.client = supabase.create_client(SUPABASE_URL, SUPABASE_SECRET_KEY)

    def save_arena(self, arena: Arena):
        """Save the environment config, player configs and messages of *arena*."""
        self._save_environment(arena)
        self._save_player_configs(arena)
        self.save_messages(arena)

    def _save_environment(self, arena: Arena):
        """Insert the ``Arena`` row and, if configured, the ``Moderator`` row."""
        env_config = arena.environment.to_config()
        # The moderator is stored in its own table, so detach it before the
        # remaining environment config is serialized into the Arena row.
        moderator_config = env_config.pop("moderator", None)
        arena_id = str(arena.uuid)

        arena_row = {
            "arena_id": arena_id,
            "global_prompt": arena.global_prompt,
            "env_type": env_config["env_type"],
            "env_config": json.dumps(env_config),
        }
        self.client.table("Arena").insert(arena_row).execute()

        if moderator_config:
            backend_config = moderator_config["backend"]
            moderator_row = {
                "moderator_id": str(
                    uuid.uuid5(arena.uuid, json.dumps(moderator_config))
                ),
                "arena_id": arena_id,
                "role_desc": moderator_config["role_desc"],
                "terminal_condition": moderator_config["terminal_condition"],
                "backend_type": backend_config["backend_type"],
                # Read optional backend settings with .get(), the same way
                # player rows do; presumably not every backend defines them.
                "temperature": backend_config.get("temperature", None),
                "max_tokens": backend_config.get("max_tokens", None),
            }
            self.client.table("Moderator").insert(moderator_row).execute()

    def _save_player_configs(self, arena: Arena):
        """Bulk-insert one ``Player`` row per player of *arena*."""
        arena_id = str(arena.uuid)
        player_rows = []
        for player in arena.players:
            player_config = player.to_config()
            backend_config = player_config["backend"]
            player_rows.append(
                {
                    "player_id": str(
                        uuid.uuid5(arena.uuid, json.dumps(player_config))
                    ),
                    "arena_id": arena_id,
                    "name": player.name,
                    "role_desc": player_config["role_desc"],
                    "backend_type": backend_config["backend_type"],
                    "temperature": backend_config.get("temperature", None),
                    "max_tokens": backend_config.get("max_tokens", None),
                }
            )
        # Skip the network request entirely when there are no players.
        if player_rows:
            self.client.table("Player").insert(player_rows).execute()

    def save_messages(self, arena: Arena, messages: List[Message] = None):
        """Insert the not-yet-logged messages and mark them as logged.

        Args:
            arena: arena whose UUID namespaces the message IDs.
            messages: messages to save; when None, defaults to
                ``arena.environment.get_observation()``.
        """
        if messages is None:
            messages = arena.environment.get_observation()
        # Only persist messages that have not been written before.
        pending = [msg for msg in messages if not msg.logged]
        if not pending:
            # Nothing new: avoid an empty insert request.
            return

        arena_id = str(arena.uuid)
        message_rows = [
            {
                "message_id": str(uuid.uuid5(arena.uuid, message.msg_hash)),
                "arena_id": arena_id,
                "agent_name": message.agent_name,
                "content": message.content,
                "turn": message.turn,
                "timestamp": str(message.timestamp),
                "msg_type": message.msg_type,
                "visible_to": json.dumps(message.visible_to),
            }
            for message in pending
        ]
        self.client.table("Message").insert(message_rows).execute()

        # Mark as logged only after the insert call returned, so a failing
        # request leaves the messages eligible to be saved again later.
        for message in pending:
            message.logged = True
def log_arena(arena: Arena, database=None):
    """Persist the whole arena through ``database.save_arena``.

    Passing ``database=None`` makes this a no-op, which lets callers turn
    logging off simply by not supplying a database.
    """
    if database is None:
        return
    database.save_arena(arena)
def log_messages(arena: Arena, messages: List[Message], database=None):
    """Persist *messages* of *arena* through ``database.save_messages``.

    Passing ``database=None`` makes this a no-op, which lets callers turn
    logging off simply by not supplying a database.
    """
    if database is None:
        return
    database.save_messages(arena, messages)