File size: 4,587 Bytes
6a0e448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import json
import logging
from abc import ABC, abstractmethod
from enum import Enum
from functools import cached_property
from typing import Any, Literal, Optional, Self

from pydantic import BaseModel

from proxy_lite.history import ToolCall
from proxy_lite.tools import Tool, ToolExecutionResponse


class EventType(str, Enum):
    """Discriminator values for ``Event.type``.

    Mixing in ``str`` makes each member compare equal to its string value
    (e.g. ``EventType.ACTION == "action"``).
    """

    OBSERVATION = "observation"
    ACTION = "action"
    # NOTE(review): no model in this module uses MESSAGE; presumably used elsewhere.
    MESSAGE = "message"


class Event(BaseModel):
    """Base model for anything tagged with an ``EventType``.

    Subclasses narrow ``type`` to a single ``Literal`` member with a default.
    """

    type: EventType


class State(BaseModel):
    """Content carried by an ``Observation``.

    Every component is optional and defaults to ``None`` when absent.
    """

    text: str | None = None
    image: str | None = None  # base64 encoded image
    html: str | None = None
    tool_responses: list[ToolExecutionResponse] | None = None


class Observation(Event):
    """Event returned by an environment from ``initialise``, ``execute_action`` and ``observe``.

    ``state`` and ``terminated`` are required; ``reward`` and ``info`` are
    optional and default to ``None``.
    """

    type: Literal[EventType.OBSERVATION] = EventType.OBSERVATION
    state: State
    terminated: bool
    reward: float | None = None
    info: dict[str, Any] | None = None


class Action(Event):
    """Event submitted to ``BaseEnvironment.execute_action``.

    Can carry free text, a list of tool calls, extra info, or any
    combination; every payload field defaults to ``None``.
    """

    type: Literal[EventType.ACTION] = EventType.ACTION
    text: str | None = None
    tool_calls: list[ToolCall] | None = None
    info: dict[str, Any] | None = None


class BaseEnvironmentConfig(BaseModel):
    """Base class for environment configuration models.

    Held as ``BaseEnvironment.config``; concrete subclasses are registered
    with ``Environments.register_environment_config``.
    """

    ...


class BaseEnvironment(BaseModel, ABC):
    """Abstract base for environments driven through ``Action``/``Observation`` events.

    Subclasses supply the tool set and implement initialisation, action
    execution, observation and evaluation. Instances are async context
    managers; the base ``__aenter__``/``__aexit__`` do no setup or teardown.
    """

    config: BaseEnvironmentConfig
    logger: logging.Logger | None = None

    class Config:
        # Presumably required so pydantic accepts the non-pydantic
        # ``logging.Logger`` field type above.
        arbitrary_types_allowed = True

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        # No teardown by default; returning None lets exceptions propagate.
        pass

    @property
    @abstractmethod
    def info_for_user(self) -> str:
        """Return user-facing information about this environment (subclass-defined)."""

    # NOTE(review): functools.cached_property does not appear to forward
    # ``__isabstractmethod__``, so ABC may not force subclasses to override
    # ``tools`` -- confirm. Once computed, the value is cached per instance.
    @cached_property
    @abstractmethod
    def tools(self) -> list[Tool]:
        """Return the tools whose public methods ``execute_tool`` may dispatch to."""

    @abstractmethod
    async def initialise(self) -> Observation:
        """Prepare the environment and return the first observation."""

    @abstractmethod
    async def execute_action(self, action: Action) -> Observation:
        """Apply ``action`` and return the resulting observation."""

    @abstractmethod
    async def observe(self) -> Observation:
        """Return an observation of the environment's current state."""

    @abstractmethod
    async def evaluate(self, **kwargs: Any) -> dict[str, Any]:
        """Return evaluation results; the meaning of ``kwargs`` is subclass-defined."""

    async def execute_tool(self, tool_call: ToolCall) -> Any:
        """Run ``tool_call`` on the first tool exposing a matching public method.

        Args:
            tool_call: The call to run. ``tool_call.function`` is indexed with
                ``"name"`` (the method to invoke) and ``"arguments"`` (keyword
                arguments, normally a JSON-encoded object).

        Returns:
            Whatever the awaited tool method returns.

        Raises:
            ValueError: If no tool exposes a public callable with that name, or
                the arguments do not decode to a JSON object.
        """
        function = tool_call.function
        name = function["name"]
        # The name may come from model output: never dispatch to private or
        # dunder attributes (e.g. "__init__"), and skip non-callable attributes
        # instead of crashing on them.
        if not name.startswith("_"):
            for tool in self.tools:
                method = getattr(tool, name, None)
                if callable(method):
                    arguments = self._parse_tool_arguments(function["arguments"])
                    return await method(**arguments)
        msg = f'No tool function with name "{name}"'
        raise ValueError(msg)

    @staticmethod
    def _parse_tool_arguments(raw: Any) -> dict[str, Any]:
        """Decode raw tool-call arguments into a keyword-argument dict.

        Accepts a JSON string (which may itself be JSON-encoded a second time),
        an already-decoded dict, or an empty/``None`` value meaning "no arguments".

        Raises:
            ValueError: If the decoded value is not a JSON object (dict). Invalid
                JSON raises ``json.JSONDecodeError``, itself a ``ValueError``.
        """
        if raw is None or (isinstance(raw, str) and not raw.strip()):
            return {}
        arguments = raw if isinstance(raw, dict) else json.loads(raw)
        if isinstance(arguments, str):
            # Double-encoded payload: the first decode yielded a JSON string.
            arguments = json.loads(arguments)
        if not isinstance(arguments, dict):
            msg = f"Tool arguments must decode to a JSON object, got {type(arguments).__name__}"
            raise ValueError(msg)
        return arguments

    async def get_info(self) -> dict[str, Any]:
        """Return extra environment info; the base implementation returns an empty dict."""
        return {}


class Environments:
    """Registry mapping names to environment classes and their config classes.

    Both mappings are class-level dicts shared by every caller. Registering a
    name that is already present silently replaces the earlier entry.
    """

    _environment_registry: dict[str, type[BaseEnvironment]] = {}
    _environment_config_registry: dict[str, type[BaseEnvironmentConfig]] = {}

    @classmethod
    def register_environment(cls, name: str):
        """
        Decorator to register an Environment class under a given name.

        The decorated class is stored in the registry and returned unchanged.

        Example:
            @Environments.register_environment("my_environment")
            class MyEnvironment(BaseEnvironment):
                ...
        """

        def decorator(env_cls: type[BaseEnvironment]) -> type[BaseEnvironment]:
            cls._environment_registry[name] = env_cls
            return env_cls

        return decorator

    @classmethod
    def register_environment_config(cls, name: str):
        """
        Decorator to register an Environment configuration class under a given name.

        The decorated class is stored in the registry and returned unchanged.

        Example:
            @Environments.register_environment_config("my_environment")
            class MyEnvironmentConfig(BaseEnvironmentConfig):
                ...
        """

        def decorator(config_cls: type[BaseEnvironmentConfig]) -> type[BaseEnvironmentConfig]:
            cls._environment_config_registry[name] = config_cls
            return config_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> type[BaseEnvironment]:
        """
        Retrieve a registered Environment class by its name.

        Raises:
            ValueError: If no such environment is found.
        """
        try:
            return cls._environment_registry[name]
        except KeyError:
            # ``from None`` drops the internal KeyError from the traceback; the
            # ValueError message already names the missing entry.
            raise ValueError(f"Environment '{name}' not found.") from None

    @classmethod
    def get_config(cls, name: str) -> type[BaseEnvironmentConfig]:
        """
        Retrieve a registered Environment configuration class by its name.

        Raises:
            ValueError: If no such configuration is found.
        """
        try:
            return cls._environment_config_registry[name]
        except KeyError:
            # Same as ``get``: hide the implementation-detail KeyError.
            raise ValueError(f"Environment config for '{name}' not found.") from None