File size: 3,750 Bytes
2a33798
 
 
5f98914
2a33798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f98914
2a33798
5f98914
 
 
2a33798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Gym environment wrappers: SettableStateEnv (overwrite the simulator state,
# e.g. for CartPole) and BaseEnv (turns transitions into text via a translator).

import gym
from gym.spaces import Discrete

class SettableStateEnv(gym.Wrapper):
    """Wrapper that lets callers overwrite the wrapped environment's state.

    No custom ``__init__`` is needed: ``gym.Wrapper.__init__`` already stores
    the wrapped environment as ``self.env``.
    """

    def set_state(self, state):
        """Overwrite the underlying environment's state with ``state``.

        Args:
            state: Value assigned to the base environment's ``state``
                attribute. Its expected shape/type is whatever the wrapped
                environment uses internally (not checked here).
        """
        # gym.Wrapper forwards attribute *reads* to the wrapped env but not
        # attribute *writes*. If ``self.env`` is itself a wrapper (as returned
        # by ``gym.make``), assigning on it would only create attributes on
        # that outer wrapper and the simulator would keep its old state.
        # Writing to ``unwrapped`` reaches the base environment; for an
        # already-bare env ``unwrapped`` is the env itself, so behaviour is
        # unchanged in that case.
        base_env = self.env.unwrapped
        base_env.state = state
        # NOTE(review): presumably the classic-control (CartPole-style)
        # "stepped after termination" marker; cleared so stepping from the
        # injected state is not treated as post-termination. Confirm against
        # the gym version in use.
        base_env.steps_beyond_terminated = None

class BaseEnv(gym.Wrapper):
    """Gym wrapper that reports observations as translator-produced summaries.

    Every reset/step feeds the raw transition to ``translator`` and returns
    the translator's summary in place of the raw observation. The most recent
    raw observation is kept in ``self.state``.
    """

    def __init__(self, env, translator):
        """Wrap ``env`` and cache the translator's static descriptions.

        Args:
            env: The environment to wrap; its ``spec.id`` is stored as
                ``env_name``.
            translator: Object providing ``obtain``, ``update``,
                ``translate``, the ``describe_*`` / ``get_*_desc_dict``
                methods, ``translate_terminate_state`` and
                ``translate_potential_next_state``.
        """
        super().__init__(env)
        self.translator = translator
        self.env_name = self.spec.id
        # Transition currently being assembled (state, action, next_state, ...).
        self.transition_data = {}
        self.game_description = self.get_game_description()
        self.goal_description = self.get_goal_description()
        self.action_description = self.get_action_description()
        self.action_desc_dict = self.get_action_desc_dict()
        self.reward_desc_dict = self.get_reward_desc_dict()

    def reset(self, **kwargs):
        """Reset the wrapped env and return the translated initial state.

        Args:
            **kwargs: Forwarded unchanged to the wrapped env's ``reset``.

        Returns:
            ``(summary, info)`` where ``info`` holds ``'future_summary'``.
            The wrapped env's own info dict is discarded.
        """
        state, _ = super().reset(**kwargs)
        self.transition_data['state'] = state
        self.translator.obtain(self.transition_data)
        summary, future_summary = self.translator.translate()
        info = {
            'future_summary': future_summary
        }
        self.state = state
        return summary, info

    def step(self, action):
        """Apply ``action`` as-is to the wrapped env.

        Returns:
            ``(summary, reward, terminated, truncated, info)``; ``info``
            holds ``'future_summary'`` and ``'potential_state'``.
        """
        return self._transition(action, action)

    def step_llm(self, action):
        """Apply an action given in the LLM-facing numbering.

        For ``Discrete`` action spaces the action is shifted down by one
        before being passed to the wrapped env; other action spaces receive
        it unchanged. The unshifted value is what gets recorded and handed
        to the translator.

        Returns:
            Same five-tuple as :meth:`step`.
        """
        if isinstance(self.action_space, Discrete):
            env_action = action - 1
        else:
            env_action = action
        return self._transition(action, env_action)

    def _transition(self, action, env_action):
        """Step the wrapped env with ``env_action`` and record ``action``.

        Shared implementation of ``step`` and ``step_llm``.
        """
        # Must be computed before stepping: it is derived from the state the
        # action is applied to (``self.state``).
        potential_next_state = self.get_potential_next_state(action)
        # The wrapped env's info dict is intentionally dropped; callers get
        # the translator-based info built below instead.
        state, reward, terminated, truncated, _ = super().step(env_action)

        self.transition_data['action'] = action
        self.transition_data['next_state'] = state
        self.transition_data['reward'] = reward
        self.transition_data['terminated'] = terminated
        self.translator.update(self.transition_data)

        # Start assembling the next transition from the new state.
        self.transition_data = {'state': state}
        self.translator.obtain(self.transition_data)
        # Keep the current observation for both step variants, so later
        # get_potential_next_state / get_terminate_state calls never see a
        # stale state (previously only step_llm updated it).
        self.state = state
        summary, future_summary = self.translator.translate()
        info = {
            'future_summary': future_summary,
            'potential_state': potential_next_state,
        }
        return summary, reward, terminated, truncated, info

    def get_terminate_state(self, episode_len, max_episode_len):
        """Return the translator's description of the terminal situation."""
        return self.translator.translate_terminate_state(self.state, episode_len, max_episode_len)

    def get_game_description(self):
        """Return ``translator.describe_game()``."""
        return self.translator.describe_game()

    def get_goal_description(self):
        """Return ``translator.describe_goal()``."""
        return self.translator.describe_goal()

    def get_action_description(self):
        """Return ``translator.describe_action()``."""
        return self.translator.describe_action()

    def get_action_desc_dict(self):
        """Return ``translator.get_action_desc_dict()``."""
        return self.translator.get_action_desc_dict()

    def get_reward_desc_dict(self):
        """Return ``translator.get_reward_desc_dict()``."""
        return self.translator.get_reward_desc_dict()

    def get_potential_next_state(self, action):
        """Return the translator's prediction for ``action`` from ``self.state``."""
        return self.translator.translate_potential_next_state(self.state, action)