from swarm_policy import SwarmPolicy
def run_episode(env, obs, blues: int, reds: int, *, max_steps=None):
    """Roll out one episode with a blue and a red SwarmPolicy acting jointly.

    A fresh policy is built for each side. At every step both policies
    predict from the same observation, and their predictions are passed to
    ``env.step`` as a ``(blue_action, red_action)`` tuple.

    NOTE(review): assumes ``env.step`` returns the classic 4-tuple
    ``(obs, reward, done, info)``. A Gymnasium-style 5-tuple
    ``(obs, reward, terminated, truncated, info)`` would break the
    unpacking below. Confirm against the env in use.

    Args:
        env: Environment exposing ``step(action)``.
        obs: Observation fed to both policies on the first step.
        blues: Forwarded unchanged to both policies (presumably the blue
            team size; verify against ``SwarmPolicy``).
        reds: Forwarded unchanged to both policies (presumably the red
            team size; verify against ``SwarmPolicy``).
        max_steps: Optional cap on the number of ``env.step`` calls.
            ``None`` (the default) runs until the environment reports
            ``done``, exactly as before.

    Returns:
        ``(obs, sum_reward, done, info)``: the last observation, the sum of
        all per-step rewards, the last ``done`` flag (``False`` only when
        ``max_steps`` cut the episode short), and the last ``info``.

    Raises:
        ValueError: if ``max_steps`` is given and is less than 1.
    """
    # Validate before building policies. With zero steps, ``info`` would
    # never be bound and the return statement would raise NameError.
    if max_steps is not None and max_steps < 1:
        raise ValueError(f"max_steps must be >= 1 or None, got {max_steps!r}")

    blue_policy = SwarmPolicy(blues=blues, reds=reds, is_blue=True)
    red_policy = SwarmPolicy(blues=blues, reds=reds, is_blue=False)

    sum_reward = 0
    done = False
    steps = 0
    while not done:
        # Both sides act on the same observation; the env receives a tuple.
        action = (blue_policy.predict(obs), red_policy.predict(obs))
        obs, reward, done, info = env.step(action)
        sum_reward += reward
        steps += 1
        # Guard against environments that never signal ``done``.
        if max_steps is not None and steps >= max_steps:
            break
    return obs, sum_reward, done, info