grg's picture
Cleaned old git history
be5548b
raw
history blame contribute delete
No virus
3.27 kB
from gym_minigrid.social_ai_envs.socialaiparamenv import SocialAIParamEnv
from gym_minigrid.parametric_env import *
from gym_minigrid.register import register
import inspect, importlib
# Used for the automatic registration of environments: snapshot the names of
# all classes visible in this module *before* the environment classes below
# are defined, so that a second snapshot taken afterwards can be diffed
# against it.
defined_classes = [
    class_name
    for class_name, _class_obj in inspect.getmembers(
        importlib.import_module(__name__), predicate=inspect.isclass
    )
]
class AppleStealingParamEnv(SocialAIParamEnv):
    """SocialAI parametric environment pinned to one AppleStealing setup.

    The parameter tree built by :meth:`construct_tree` contains exactly one
    value under every parameter node, selected by the constructor arguments:

        Env_type (param)
          AppleStealing (value)
            Version (param)
              Asocial (value)                  -- when ``asocial`` is truthy
              Social (value)                   -- otherwise
                NPC_movement (param)
                  Walking | Rotating (value)   -- chosen by ``walk``
            Obstacles (param)
              <obstacles> (value)
    """

    def __init__(self, obstacles, asocial, walk, **kwargs):
        """Store the configuration and delegate to the parent constructor.

        Args:
            obstacles: Obstacle amount; ``construct_tree`` accepts only
                "No", "A_bit", "Medium" or "A_lot".
            asocial: Truthy selects the "Asocial" version, falsy the
                "Social" one.
            walk: Truthy selects "Walking" NPC movement, falsy "Rotating".
                Only consulted for the social version.
            **kwargs: Forwarded unchanged to ``SocialAIParamEnv.__init__``.
        """
        # NOTE(review): these are set before the parent constructor runs,
        # presumably because it calls construct_tree(), which reads them --
        # confirm against SocialAIParamEnv.
        self.asocial = asocial
        self.obstacles = obstacles
        self.walk = walk

        super().__init__(**kwargs)

    def construct_tree(self):
        """Build and return the ParameterTree for this configuration.

        Returns:
            The populated ``ParameterTree`` (see the class docstring for its
            layout).

        Raises:
            ValueError: If ``self.obstacles`` is not one of "No", "A_bit",
                "Medium" or "A_lot".
        """
        tree = ParameterTree()

        root_nd = tree.add_node("Env_type", type="param")

        # The only environment type offered by this class.
        # (An intermediate "Problem" param level between this node and its
        # children existed as commented-out code and remains disabled.)
        stealing_nd = tree.add_node("AppleStealing", parent=root_nd, type="value")

        # Version: asocial, or social with a walking / rotating NPC.
        version_nd = tree.add_node("Version", parent=stealing_nd, type="param")
        if self.asocial:
            tree.add_node("Asocial", parent=version_nd, type="value")
        else:
            social_nd = tree.add_node("Social", parent=version_nd, type="value")
            movement_nd = tree.add_node("NPC_movement", parent=social_nd, type="param")
            movement_label = "Walking" if self.walk else "Rotating"
            tree.add_node(movement_label, parent=movement_nd, type="value")

        # Obstacles: the configured amount becomes the single allowed value.
        obstacles_nd = tree.add_node("Obstacles", parent=stealing_nd, type="param")

        allowed_amounts = ("No", "A_bit", "Medium", "A_lot")
        if self.obstacles not in allowed_amounts:
            raise ValueError("Undefined obstacle amount.")

        tree.add_node(self.obstacles, parent=obstacles_nd, type="value")

        return tree
# automatic registration of environments
# Classes defined in this file = classes visible now minus the snapshot taken
# before the class definitions. They are only sanity-checked for the "...Env"
# naming convention here; the registration loop below does not use `envs`.
defined_classes_ = [
    class_name
    for class_name, _class_obj in inspect.getmembers(
        importlib.import_module(__name__), predicate=inspect.isclass
    )
]
envs = list(set(defined_classes_) - set(defined_classes))
assert all(e.endswith("Env") for e in envs)

# Register the testing envs: {asocial, social} x obstacle amount, and for the
# social version additionally x {walking, rotating} NPC. Resulting ids:
#   SocialAI-AsocialAppleStealingObst_<obst>ParamEnv-v1
#   SocialAI-AppleStealingWalkObst_<obst>ParamEnv-v1
#   SocialAI-AppleStealingObst_<obst>ParamEnv-v1
for asocial in (True, False):
    for obst in ("No", "A_bit", "Medium", "A_lot"):
        # The asocial version is registered once, with walk=False; the social
        # version is registered for walking first, then rotating.
        for walk in ((False,) if asocial else (True, False)):
            env_name = "{}AppleStealing{}Obst_{}ParamEnv".format(
                "Asocial" if asocial else "",
                "Walk" if walk else "",
                obst,
            )
            register(
                id=f"SocialAI-{env_name}-v1",
                entry_point="gym_minigrid.social_ai_envs:AppleStealingParamEnv",
                kwargs={
                    "asocial": asocial,
                    "obstacles": obst,
                    "walk": walk,
                },
            )