#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time    : 2024/1/4 16:32
@Author  : alexanderwu
@File    : context.py
"""
import os
from pathlib import Path
from typing import Any, Dict, Optional

from pydantic import BaseModel, ConfigDict

from metagpt.config2 import Config
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import create_llm_instance
from metagpt.utils.cost_manager import (
    CostManager,
    FireworksCostManager,
    TokenCostManager,
)
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo


class AttrDict(BaseModel):
    """A dict-like object that allows access to keys as attributes, compatible with Pydantic.

    Values are stored directly in the instance ``__dict__``. Reading an
    attribute that does not exist returns ``None`` instead of raising
    ``AttributeError``.
    """

    model_config = ConfigDict(extra="allow")

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mirror the keyword arguments into __dict__ so that attribute access
        # and get()/set() all operate on the same storage.
        self.__dict__.update(kwargs)

    def __getattr__(self, key):
        # Only invoked when normal attribute lookup fails: missing keys read as None.
        return self.__dict__.get(key, None)

    def __setattr__(self, key, value):
        # Write straight into __dict__ rather than going through BaseModel.__setattr__.
        self.__dict__[key] = value

    def __delattr__(self, key):
        if key in self.__dict__:
            del self.__dict__[key]
        else:
            raise AttributeError(f"No such attribute: {key}")

    def set(self, key, val: Any):
        """Store ``val`` under ``key``."""
        self.__dict__[key] = val

    def get(self, key, default: Any = None):
        """Return the value stored under ``key``, or ``default`` if it is absent."""
        return self.__dict__.get(key, default)

    def remove(self, key):
        """Delete ``key`` if present; do nothing if it is absent."""
        if key in self.__dict__:
            self.__delattr__(key)


class Context(BaseModel):
    """Env context for MetaGPT"""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    kwargs: AttrDict = AttrDict()
    config: Config = Config.default()

    repo: Optional[ProjectRepo] = None
    git_repo: Optional[GitRepository] = None
    src_workspace: Optional[Path] = None
    cost_manager: CostManager = CostManager()

    _llm: Optional[BaseLLM] = None

    def new_environ(self):
        """Return a copy of the current process environment.

        Returns:
            dict: A plain ``dict`` copy of ``os.environ``. Mutating it does not
            affect the real process environment.
        """
        env = os.environ.copy()
        # i = self.options
        # env.update({k: v for k, v in i.items() if isinstance(v, str)})
        return env

    def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
        """Return a CostManager instance appropriate for ``llm_config.api_type``.

        ``FIREWORKS`` gets a fresh ``FireworksCostManager``, and ``OPEN_LLM``
        gets a fresh ``TokenCostManager``. Any other API type shares this
        context's own ``self.cost_manager`` object.
        """
        if llm_config.api_type == LLMType.FIREWORKS:
            return FireworksCostManager()
        elif llm_config.api_type == LLMType.OPEN_LLM:
            return TokenCostManager()
        else:
            return self.cost_manager

    def llm(self) -> BaseLLM:
        """Build an LLM from ``self.config.llm``, store it in ``self._llm`` and return it.

        A new instance is created on every call; there is no caching yet.
        If the created instance has no cost manager, one is attached via
        ``_select_costmanager``.
        """
        # fixme: support cache (e.g. only rebuild when ``self._llm is None``).
        self._llm = create_llm_instance(self.config.llm)
        if self._llm.cost_manager is None:
            self._llm.cost_manager = self._select_costmanager(self.config.llm)
        return self._llm

    def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM:
        """Build and return a new LLM for ``llm_config`` without touching ``self._llm``.

        If the created instance has no cost manager, one is attached via
        ``_select_costmanager``. fixme: support cache.
        """
        llm = create_llm_instance(llm_config)
        if llm.cost_manager is None:
            llm.cost_manager = self._select_costmanager(llm_config)
        return llm

    def serialize(self) -> Dict[str, Any]:
        """Serialize the object's attributes into a dictionary.

        Returns:
            Dict[str, Any]: A dictionary with three keys:
                ``"workdir"`` (the repo workdir as a string, or ``""`` without a repo),
                ``"kwargs"`` (a shallow copy of the AttrDict contents), and
                ``"cost_manager"`` (the cost manager as a JSON string).
        """
        return {
            "workdir": str(self.repo.workdir) if self.repo else "",
            "kwargs": dict(self.kwargs.__dict__),
            "cost_manager": self.cost_manager.model_dump_json(),
        }

    def deserialize(self, serialized_data: Dict[str, Any]):
        """Deserialize the given serialized data and update the object's attributes accordingly.

        Args:
            serialized_data (Dict[str, Any]): A dictionary as produced by ``serialize``.
                An empty or ``None`` value leaves the context unchanged.
        """
        if not serialized_data:
            return

        workdir = serialized_data.get("workdir")
        if workdir:
            self.git_repo = GitRepository(local_path=workdir, auto_init=True)
            self.repo = ProjectRepo(self.git_repo)
            # The source workspace is the sub-directory named after the workdir itself.
            src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
            if src_workspace.exists():
                self.src_workspace = src_workspace

        kwargs = serialized_data.get("kwargs")
        if kwargs:
            for k, v in kwargs.items():
                self.kwargs.set(k, v)

        cost_manager = serialized_data.get("cost_manager")
        if cost_manager:
            # model_validate_json is a classmethod that returns a NEW model;
            # its result must be assigned, otherwise the restored data is lost.
            # NOTE(review): LLM instances created before this call may still hold
            # a reference to the previous cost manager object.
            self.cost_manager = CostManager.model_validate_json(cost_manager)