Spaces:
Runtime error
Runtime error
from typing import Dict, List
class EnvironmentHistory:
    """Ordered record of labelled environment events, renderable as text.

    Entries are stored in insertion order as ``{'label': ..., 'value': ...}``
    dicts. Every rendering method formats each selected entry according to
    its label (see ``_format_item``) and joins the pieces with newlines; the
    result never carries a trailing newline.
    """

    # Labels accepted by ``add``. The 'cummulative_reward' spelling is kept
    # unchanged because it is the label string ``add`` accepts (interface).
    _VALID_LABELS = (
        'action',
        'observation',
        'human_edit',
        'reward',
        'cummulative_reward',
        'terminate_state',
    )

    def __init__(self) -> None:
        self._history: List[Dict[str, str]] = []

    def add(self, label: str, value: str) -> None:
        """Append one ``(label, value)`` entry to the history.

        Raises:
            ValueError: if ``label`` is not one of ``_VALID_LABELS``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if label not in self._VALID_LABELS:
            raise ValueError(f'Unknown history label: {label!r}')
        self._history.append({'label': label, 'value': value})

    def reset(self) -> None:
        """Remove all entries."""
        self._history = []

    @staticmethod
    def _format_item(item: Dict[str, str], reward_prefix: str = '') -> str:
        """Render a single entry as text according to its label.

        ``reward_prefix`` is prepended to 'reward' values only. Entries with
        an unrecognised label render as the empty string.
        """
        label = item['label']
        value = item['value']
        if label == 'action':
            return f'He takes action: {value}'
        if label == 'observation':
            return f'{value}'
        if label == 'reward':
            return f'{reward_prefix}{value}'
        if label == 'cummulative_reward':
            return f'Performance: {value}'
        # NOT CURRENTLY SUPPORTED
        if label == 'human_edit':
            return f'[human edit]: {value}'
        if label == 'terminate_state':
            return f'{value}'
        return ''

    def _render(self, items: List[Dict[str, str]], reward_prefix: str = '') -> str:
        """Format ``items`` and join them with newlines (no trailing newline)."""
        return '\n'.join(self._format_item(item, reward_prefix) for item in items)

    def _entries_per_step(self) -> int:
        """Count distinct labels in the history, ignoring 'cummulative_reward'.

        The windowing methods treat this as the number of entries logged per
        step. NOTE(review): that assumes each step records every one of these
        labels exactly once — confirm against callers.
        """
        labels = {item['label'] for item in self._history}
        labels.discard('cummulative_reward')
        return len(labels)

    def __str__(self) -> str:
        """Render the full history."""
        return self._render(self._history)

    def get_one_history(self) -> str:
        """Render the first step: the first ``_entries_per_step()`` entries."""
        return self._render(self._history[:self._entries_per_step()])

    def set_history(self, num: int) -> None:
        """Keep only the most recent ``num`` entries.

        ``num <= 0`` clears the history (a bare ``[-0:]`` slice would
        otherwise keep every entry).
        """
        if num <= 0:
            self._history = []
        elif len(self._history) > num:
            self._history = self._history[-num:]

    def get_last_history(self) -> str:
        """Render only the most recent entry ('' when the history is empty)."""
        return self._render(self._history[-1:])

    def get_histories(self, num: int) -> str:
        """Render the ``num`` steps preceding the most recent entry.

        Selects the ``_entries_per_step() * num`` entries immediately before
        the last entry; the last entry itself is excluded.
        """
        history_num = self._entries_per_step() * num + 1
        return self._render(self._history[-history_num:-1])

    def get_histories_with_last(self, num: int) -> str:
        """Render the last ``_entries_per_step() * num + 1`` entries.

        Same window as ``get_histories`` but including the most recent entry;
        'reward' values are prefixed with 'Reward after taking action: '.
        """
        history_num = self._entries_per_step() * num + 1
        return self._render(
            self._history[-history_num:],
            reward_prefix='Reward after taking action: ',
        )

    def remove_invalid_state(self) -> None:
        """Drop the most recent entry (no-op on an empty history)."""
        self._history = self._history[:-1]

    def __len__(self) -> int:
        """Return the number of 'action' entries, not the number of entries."""
        return sum(1 for item in self._history if item['label'] == 'action')