File size: 1,274 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
import pytest
import numpy as np
from easydict import EasyDict
from dizoo.slime_volley.envs.slime_volley_env import SlimeVolleyEnv
@pytest.mark.envtest
class TestSlimeVolley:
    """Smoke tests that roll out ``SlimeVolleyEnv`` episodes with random actions."""

    @pytest.mark.parametrize('agent_vs_agent', [True, False])
    def test_slime_volley(self, agent_vs_agent):
        """Run one episode with random actions and sanity-check what ``step`` returns.

        Args:
            agent_vs_agent: Passed through to the env config. When True, the
                first entry of ``rewards`` is accumulated and per-entry
                ``infos[0]`` / ``infos[1]`` are read; when False, ``rewards``
                is accumulated as-is.

        On every step this asserts that the first two entries of
        ``observations`` have the same shape. In agent-vs-agent mode it also
        asserts, after the episode ends, that at least one of the two final
        ``'ale.lives'`` values is 0. The env is always closed, even if an
        assertion fails mid-episode.
        """
        total_return = 0
        env = SlimeVolleyEnv(EasyDict({'env_id': 'SlimeVolley-v0', 'agent_vs_agent': agent_vs_agent}))
        try:
            # env.enable_save_replay('replay_video')
            obs1 = env.reset()
            print(env.observation_space)
            print('observation is like:', obs1)
            done = False
            while not done:
                action = env.random_action()
                observations, rewards, done, infos = env.step(action)
                if agent_vs_agent:
                    # Only the first entry's reward is tracked in agent-vs-agent mode.
                    total_return += rewards[0]
                else:
                    total_return += rewards
                obs1, obs2 = observations[0], observations[1]
                assert obs1.shape == obs2.shape, (obs1.shape, obs2.shape)
                if agent_vs_agent:
                    # Overwritten every step; only the values from the final step are checked below.
                    agent_lives, opponent_lives = infos[0]['ale.lives'], infos[1]['ale.lives']
            if agent_vs_agent:
                # The loop body runs at least once (done starts False), so both names are bound here.
                assert agent_lives == 0 or opponent_lives == 0, (agent_lives, opponent_lives)
            print("total return is:", total_return)
        finally:
            # Release the environment even when an assertion above fails.
            env.close()
|