MinhNH
Initial commit
48c5871
raw
history blame
1.17 kB
import numpy as np
from utils.util import *
class Evaluator(object):
    """Roll a policy out on an environment for a number of validation episodes.

    The evaluator is callable: ``evaluator(env, policy)`` runs
    ``validate_episodes`` episodes, asks the environment to save an image
    after every step, and returns the reward and distance of the last
    episode.
    """

    def __init__(self, args, writer):
        """Copy the evaluation settings from ``args``.

        Args:
            args: Object exposing ``validate_episodes``, ``max_step`` and
                ``env_batch`` attributes.
            writer: Kept on the instance as ``self.writer``; this class does
                not use it itself.
        """
        self.validate_episodes = args.validate_episodes
        self.max_step = args.max_step
        self.env_batch = args.env_batch
        self.writer = writer
        # Number of episodes completed over the lifetime of this evaluator;
        # passed to env.save_image as its first argument.
        self.log = 0

    def __call__(self, env, policy, debug=False):
        """Run the validation episodes and return the last episode's results.

        Args:
            env: Environment providing ``reset(test=..., episode=...)``,
                ``step(action)`` returning a 4-tuple
                ``(observation, reward, done, extra)``, ``save_image(log,
                step)`` and ``get_dist()``.
            policy: Callable mapping an observation to an action.
            debug: Accepted for interface compatibility; not used here.

        Returns:
            ``(episode_reward, dist)`` of the final episode only.
            ``episode_reward`` starts as ``np.zeros(self.env_batch)`` and has
            every step reward added to it; ``dist`` is the value returned by
            ``env.get_dist()`` at the end of that episode.

        Raises:
            ValueError: If ``validate_episodes`` is less than 1, since there
                would be no episode to report.
        """
        if self.validate_episodes < 1:
            raise ValueError(
                "validate_episodes must be at least 1, got %r"
                % (self.validate_episodes,)
            )

        # A falsy max_step (0 or None) means "no step cap". Deciding this up
        # front also avoids comparing an int with None in the loop condition.
        unbounded = not self.max_step

        episode_reward = None
        dist = None
        for episode in range(self.validate_episodes):
            # Reset at the start of each episode.
            observation = env.reset(test=True, episode=episode)
            assert observation is not None, "env.reset() returned None"

            episode_reward = np.zeros(self.env_batch)
            episode_steps = 0
            while unbounded or episode_steps < self.max_step:
                action = policy(observation)
                observation, reward, done, _ = env.step(action)
                episode_reward += reward
                episode_steps += 1
                env.save_image(self.log, episode_steps)
                # With a step cap the episode always runs exactly max_step
                # steps (unchanged behaviour). Without a cap, `done` is the
                # only termination signal, so it must be honoured here.
                if unbounded and np.all(done):
                    break

            dist = env.get_dist()
            self.log += 1
        return episode_reward, dist