| | import asyncio |
| | import logging |
| | import os |
| | import subprocess |
| | from datetime import datetime |
| | from logging import FileHandler |
| | from typing import Any, Generator, List, cast |
| |
|
| | import gin |
| | from absl import flags |
| | from rich.logging import RichHandler |
| | from sotopia.agents import LLMAgent |
| | from sotopia.database import ( |
| | AgentProfile, |
| | EnvAgentComboStorage, |
| | EnvironmentProfile, |
| | EpisodeLog, |
| | ) |
| | from sotopia.envs.evaluators import ( |
| | EvaluationForTwoAgents, |
| | ReachGoalLLMEvaluator, |
| | RuleBasedTerminatedEvaluator, |
| | SotopiaDimensions, |
| | ) |
| | from sotopia.envs.parallel import ParallelSotopiaEnv |
| | from sotopia.generation_utils.generate import LLM_Name |
| | from sotopia.messages import AgentAction, Observation |
| | from sotopia.samplers import BaseSampler, ConstraintBasedSampler, EnvAgentCombo |
| | from sotopia.server import run_async_server |
| | from sotopia_conf.gin_utils import parse_gin_flags, run |
| | from tqdm import tqdm |
| |
|
| | _DEFAULT_GIN_SEARCH_PATHS = [ |
| | os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
| | ] |
| | FLAGS = flags.FLAGS |
| |
|
| | |
| | FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" |
| |
|
| | process = subprocess.Popen( |
| | ["git", "rev-parse", "HEAD"], shell=False, stdout=subprocess.PIPE |
| | ) |
| | git_head_hash = process.communicate()[0].strip() |
| |
|
| | logging.basicConfig( |
| | level=15, |
| | format=FORMAT, |
| | datefmt="[%X]", |
| | handlers=[ |
| | RichHandler(), |
| | FileHandler( |
| | datetime.now().strftime( |
| | f"./logs/%H_%M_%d_%m_%Y_{str(git_head_hash.decode('utf-8'))}.log" |
| | ) |
| | ), |
| | ], |
| | ) |
| |
|
| | env_ids: list[str] = list(EnvironmentProfile.all_pks()) |
| | assert all( |
| | isinstance(env_id, str) for env_id in env_ids |
| | ), "env_ids should be a list of strings" |
| |
|
| |
|
| | def check_existing_episodes( |
| | env_id: str, |
| | agent_ids: list[str], |
| | models: dict[str, LLM_Name], |
| | tag: str | None = None, |
| | ) -> bool: |
| | if tag: |
| | existing_episode = EpisodeLog.find( |
| | (EpisodeLog.environment == env_id) & (EpisodeLog.tag == tag) |
| | ).all() |
| | else: |
| | existing_episode = EpisodeLog.find(EpisodeLog.environment == env_id).all() |
| | if existing_episode: |
| | for episode in existing_episode: |
| | assert isinstance(episode, EpisodeLog), "episode should be an EpisodeLog" |
| | if episode.agents == agent_ids and episode.models == list(models.values()): |
| | return True |
| | return False |
| | else: |
| | return False |
| |
|
| |
|
| | def _sample_env_agent_combo_and_push_to_db(env_id: str) -> None: |
| | sampler = ConstraintBasedSampler[Observation, AgentAction](env_candidates=[env_id]) |
| | env_agent_combo_list = list( |
| | sampler.sample(agent_classes=[LLMAgent] * 2, replacement=False) |
| | ) |
| | for env, agent in env_agent_combo_list: |
| | EnvAgentComboStorage( |
| | env_id=env.profile.pk, |
| | agent_ids=[agent[0].profile.pk, agent[1].profile.pk], |
| | ).save() |
| |
|
| |
|
| | @gin.configurable |
| | def _iterate_env_agent_combo_not_in_db( |
| | model_names: dict[str, LLM_Name], |
| | env_ids: list[str] = [], |
| | tag: str | None = None, |
| | ) -> Generator[EnvAgentCombo[Observation, AgentAction], None, None]: |
| | """We iterate over each environment and return the **first** env-agent combo that is not in the database.""" |
| | if not env_ids: |
| | env_ids = list(EnvironmentProfile.all_pks()) |
| |
|
| | all_env_agent_combo_storage_list: List[EnvAgentComboStorage] = [] |
| | for env_id in env_ids: |
| | assert env_id is not None, "env_id should not be None" |
| | env_agent_combo_storage_list = list( |
| | EnvAgentComboStorage.find(EnvAgentComboStorage.env_id == env_id).all() |
| | ) |
| | if not env_agent_combo_storage_list: |
| | _sample_env_agent_combo_and_push_to_db(env_id) |
| | env_agent_combo_storage_list = list( |
| | EnvAgentComboStorage.find(EnvAgentComboStorage.env_id == env_id).all() |
| | ) |
| | assert env_agent_combo_storage_list |
| |
|
| | |
| | env_agent_combo_storage_list = sorted( |
| | env_agent_combo_storage_list, key=lambda x: x.pk |
| | )[:1] |
| |
|
| | for env_agent_combo_storage in env_agent_combo_storage_list: |
| | env_agent_combo_storage = cast( |
| | EnvAgentComboStorage, env_agent_combo_storage |
| | ) |
| | agent_ids = env_agent_combo_storage.agent_ids |
| | if check_existing_episodes(env_id, agent_ids, model_names, tag): |
| | logging.info( |
| | f"Episode for {env_id} with agents {agent_ids} using {list(model_names.values())} already exists" |
| | ) |
| | print(f"Episode for {env_id} with agents {agent_ids} using {list(model_names.values())} already exists") |
| | continue |
| | else: |
| | all_env_agent_combo_storage_list.append(env_agent_combo_storage) |
| | print(f"Number of env agent combos to run: {len(all_env_agent_combo_storage_list)}") |
| |
|
| | for env_agent_combo_storage in all_env_agent_combo_storage_list: |
| | env_profile = EnvironmentProfile.get(env_agent_combo_storage.env_id) |
| | env = ParallelSotopiaEnv( |
| | env_profile=env_profile, |
| | model_name=model_names["env"], |
| | action_order="round-robin", |
| | evaluators=[ |
| | RuleBasedTerminatedEvaluator(max_turn_number=20, max_stale_turn=2), |
| | ], |
| | terminal_evaluators=[ |
| | ReachGoalLLMEvaluator( |
| | model_names["env"], |
| | EvaluationForTwoAgents[SotopiaDimensions], |
| | ), |
| | ], |
| | ) |
| | agent_profiles = [AgentProfile.get(id) for id in env_agent_combo_storage.agent_ids] |
| |
|
| | agents = [ |
| | LLMAgent(agent_profile=agent_profile, model_name=agent_model) |
| | for agent_profile, agent_model in zip( |
| | agent_profiles, |
| | [model_names["agent1"], model_names["agent2"]], |
| | ) |
| | ] |
| | yield env, agents |
| |
|
| |
|
| | @gin.configurable |
| | def run_async_server_in_batch( |
| | *, |
| | batch_size: int = 1, |
| | model_names: dict[str, LLM_Name] = { |
| | "env": "gpt-4", |
| | "agent1": "gpt-4o-mini", |
| | "agent2": "gpt-4o-mini", |
| | }, |
| | tag: str | None = None, |
| | verbose: bool = False, |
| | ) -> None: |
| | if not verbose: |
| | logger = logging.getLogger() |
| | logger.setLevel(logging.CRITICAL) |
| | rich_handler = logger.handlers[0] |
| | logger.removeHandler(rich_handler) |
| |
|
| | |
| | env_agent_combo_iter = _iterate_env_agent_combo_not_in_db(model_names=model_names, tag=tag) |
| | env_agent_combo_iter_length = sum(1 for _ in env_agent_combo_iter) |
| |
|
| | env_agent_combo_iter = _iterate_env_agent_combo_not_in_db(model_names=model_names, tag=tag) |
| | env_agent_combo_batch: list[EnvAgentCombo[Observation, AgentAction]] = [] |
| |
|
| | while True: |
| | for env_agent_combo in tqdm( |
| | env_agent_combo_iter, |
| | total=env_agent_combo_iter_length, |
| | desc="Running all envs in batch", |
| | ): |
| | env_agent_combo_batch.append(env_agent_combo) |
| | if len(env_agent_combo_batch) == batch_size: |
| | logging.info( |
| | f"Running batch of {batch_size} episodes: {env_agent_combo_batch}" |
| | ) |
| | asyncio.run( |
| | run_async_server( |
| | model_dict=model_names, |
| | sampler=BaseSampler[Observation, AgentAction](), |
| | env_agent_combo_list=env_agent_combo_batch, |
| | ) |
| | ) |
| | env_agent_combo_batch = [] |
| | else: |
| | if env_agent_combo_batch: |
| | logging.info( |
| | f"Running batch of {batch_size} episodes: {env_agent_combo_batch}" |
| | ) |
| | asyncio.run( |
| | run_async_server( |
| | model_dict=model_names, |
| | sampler=BaseSampler[Observation, AgentAction](), |
| | env_agent_combo_list=env_agent_combo_batch, |
| | ) |
| | ) |
| | return |
| |
|
| |
|
| | def main(_: Any) -> None: |
| | parse_gin_flags( |
| | |
| | FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, |
| | FLAGS.gin_file, |
| | FLAGS.gin_bindings, |
| | ) |
| | run_async_server_in_batch() |
| |
|
| |
|
| | if __name__ == "__main__": |
| | flags.DEFINE_multi_string( |
| | "gin_file", |
| | default=None, |
| | help="Path to gin configuration file. Multiple paths may be passed and " |
| | "will be imported in the given order, with later configurations " |
| | "overriding earlier ones.", |
| | ) |
| |
|
| | flags.DEFINE_multi_string( |
| | "gin_bindings", default=[], help="Individual gin bindings." |
| | ) |
| |
|
| | flags.DEFINE_list( |
| | "gin_search_paths", |
| | default=["."], |
| | help="Comma-separated list of gin config path prefixes to be prepended " |
| | "to suffixes given via `--gin_file`. If a file appears in. Only the " |
| | "first prefix that produces a valid path for each suffix will be " |
| | "used.", |
| | ) |
| |
|
| | run(main) |
| |
|