File size: 3,438 Bytes
51ff9e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from pydantic import BaseModel, Field

from openhands.core.logger import openhands_logger as logger
from openhands.events.action import (
    Action,
    ChangeAgentStateAction,
    MessageAction,
    NullAction,
)
from openhands.events.event import EventSource
from openhands.events.observation import (
    AgentStateChangedObservation,
    NullObservation,
    Observation,
)
from openhands.events.serialization.event import event_to_dict
from openhands.security.invariant.nodes import Function, Message, ToolCall, ToolOutput

# Any node type that can appear in a trace list handled by this module:
# chat messages, tool calls, tool outputs, or bare function descriptors.
TraceElement = Message | ToolCall | ToolOutput | Function


def get_next_id(trace: list[TraceElement]) -> str:
    """Return the smallest positive integer, as a string, not used as a tool-call id.

    Only ``ToolCall`` elements are inspected; every other element in ``trace``
    is ignored.

    Args:
        trace: Trace elements recorded so far. It is only read.

    Returns:
        The lowest ``str(i)`` with ``i >= 1`` that is not the ``id`` of any
        ``ToolCall`` in ``trace``; ``'1'`` when the trace has no tool calls.
    """
    # A set gives O(1) membership tests; ids are compared as strings.
    # NOTE(review): assumes ToolCall.id is hashable (expected to be a str).
    used_ids = {el.id for el in trace if isinstance(el, ToolCall)}
    # With k distinct used ids, at least one of the k + 1 candidates below is
    # free, so the loop always returns.
    for i in range(1, len(used_ids) + 2):
        candidate = str(i)
        if candidate not in used_ids:
            return candidate
    # Unreachable; kept so every code path visibly returns a str.
    return '1'


def get_last_id(
    trace: list[TraceElement],
) -> str | None:
    """Return the id of the most recent ``ToolCall`` in ``trace``.

    The trace is scanned from its end towards its start, and the first
    ``ToolCall`` encountered supplies the id.

    Args:
        trace: Trace elements recorded so far. It is only read.

    Returns:
        The ``id`` of the last ``ToolCall`` in ``trace``, or ``None`` when the
        trace contains no tool call.
    """
    newest_first_ids = (
        element.id for element in reversed(trace) if isinstance(element, ToolCall)
    )
    return next(newest_first_ids, None)


def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement]:
    """Convert one action into the trace elements it contributes.

    Conversion rules:

    * ``MessageAction`` becomes a single ``Message`` whose role is ``'user'``
      when the action's source is ``EventSource.USER`` and ``'assistant'``
      otherwise.
    * ``NullAction`` and ``ChangeAgentStateAction`` contribute nothing.
    * Any other action whose ``action`` attribute is not ``None`` becomes a
      ``ToolCall`` wrapping a ``Function`` named after that attribute, with the
      ``'args'`` of the serialized event as arguments. If those args hold a
      non-``None`` ``'thought'``, it is removed from the arguments and emitted
      first as an assistant ``Message``.
    * Anything else is logged as an error and contributes nothing.

    Args:
        trace: Trace elements recorded so far. It is only read, to choose an
            unused id for a new tool call.
        action: The action to convert.

    Returns:
        The new elements, in order. The caller is responsible for appending
        them to ``trace``. The list may be empty.
    """
    if isinstance(action, MessageAction):
        role = 'user' if action.source == EventSource.USER else 'assistant'
        return [Message(role=role, content=action.content)]

    if isinstance(action, (NullAction, ChangeAgentStateAction)):
        return []

    action_name = getattr(action, 'action', None)
    if action_name is None:
        logger.error(f'Unknown action type: {type(action)}')
        return []

    args = event_to_dict(action).get('args', {})
    # The thought is reported as its own message, not as a tool argument.
    thought = args.pop('thought', None)
    function = Function(name=action_name, arguments=args)

    inv_trace: list[TraceElement] = []
    if thought is not None:
        inv_trace.append(Message(role='assistant', content=thought))
    # Only tool calls need an id, so the trace is scanned here rather than on
    # every action (messages and no-op actions never use it).
    next_id = get_next_id(trace)
    inv_trace.append(ToolCall(id=next_id, type='function', function=function))
    return inv_trace


def parse_observation(
    trace: list[TraceElement], obs: Observation
) -> list[TraceElement]:
    """Convert one observation into the trace elements it contributes.

    ``NullObservation`` and ``AgentStateChangedObservation`` contribute
    nothing. Any other observation with a non-``None`` ``content`` attribute
    becomes a single ``ToolOutput`` with role ``'tool'``, linked to the most
    recent ``ToolCall`` in ``trace``. Observations without usable content are
    logged as an error and contribute nothing.

    Args:
        trace: Trace elements recorded so far. It is only read, to find the
            id of the latest tool call.
        obs: The observation to convert.

    Returns:
        A list with zero or one ``ToolOutput``. Its ``tool_call_id`` is
        ``None`` when ``trace`` contains no ``ToolCall``.
    """
    if isinstance(obs, (NullObservation, AgentStateChangedObservation)):
        return []

    content = getattr(obs, 'content', None)
    if content is None:
        logger.error(f'Unknown observation type: {type(obs)}')
        return []

    # The trace is scanned only when an output is actually emitted; ignored
    # and unknown observations never need the id.
    last_id = get_last_id(trace)
    return [ToolOutput(role='tool', content=content, tool_call_id=last_id)]


def parse_element(
    trace: list[TraceElement], element: Action | Observation
) -> list[TraceElement]:
    """Convert a single event, action or observation, into trace elements.

    Args:
        trace: Trace elements recorded so far, forwarded to the parser.
        element: The event to convert. Instances of ``Action`` go to
            ``parse_action``; everything else goes to ``parse_observation``.

    Returns:
        Whatever the selected parser returns for ``element``.
    """
    if not isinstance(element, Action):
        return parse_observation(trace, element)
    return parse_action(trace, element)


def parse_trace(trace: list[tuple[Action, Observation]]) -> list[TraceElement]:
    """Convert a sequence of (action, observation) pairs into one flat trace.

    Pairs are processed in order. Each parser is handed the trace built so
    far, so tool-call ids and tool-output links reflect everything that was
    converted before it, including the action of the same pair.

    Args:
        trace: Pairs of an action and the observation that followed it.

    Returns:
        A new list holding the converted elements of all pairs, in order.
    """
    converted: list[TraceElement] = []
    for pair in trace:
        step_action, step_obs = pair
        # In-place `+=` on a list appends the new elements, like `extend`.
        converted += parse_action(converted, step_action)
        converted += parse_observation(converted, step_obs)
    return converted


class InvariantState(BaseModel):
    """Accumulates converted trace elements as actions and observations arrive."""

    # Converted elements in arrival order; each instance gets its own list.
    trace: list[TraceElement] = Field(default_factory=list)

    def add_action(self, action: Action) -> None:
        """Convert ``action`` against the current trace and append the result."""
        new_elements = parse_action(self.trace, action)
        self.trace.extend(new_elements)

    def add_observation(self, obs: Observation) -> None:
        """Convert ``obs`` against the current trace and append the result."""
        new_elements = parse_observation(self.trace, obs)
        self.trace.extend(new_elements)

    def concatenate(self, other: 'InvariantState') -> None:
        """Append every element of ``other.trace`` to this state's trace.

        The elements are appended as-is: they are not copied and their
        tool-call ids are not renumbered. ``other`` is left unchanged.
        """
        incoming = other.trace
        self.trace.extend(incoming)