# bvk1ng's picture
# Stage-1 commit: Agent trained for 3500 episodes
# c121225
# raw
# history blame
# 1.17 kB
"""
@author: bvk1ng (Adityam Ghosh)
Date: 12/28/2023
"""
from typing import Callable, List, Tuple, Any, Dict, Union
import numpy as np
import gymnasium as gym
def update_state(state: np.ndarray, obs_small: np.ndarray) -> np.ndarray:
    """Advance a stack of frames by one step (first in, first out).

    ``state`` is a 3-D stack whose axis 2 indexes frames. The frame at index 0
    of that axis (the oldest) is discarded, and ``obs_small`` is appended as
    the newest frame. The number of frames therefore stays the same.

    Args:
        state: Stacked frames; axis 2 is the frame axis.
        obs_small: A single 2-D frame whose shape matches the first two axes
            of ``state``.

    Returns:
        A newly allocated array. ``state`` itself is left untouched.
    """
    # Everything except the oldest frame, in the original order.
    kept_frames = state[:, :, 1:]
    # Give the new frame a trailing length-1 frame axis so it can be joined on axis 2.
    newest_frame = np.expand_dims(obs_small, axis=2)
    return np.concatenate((kept_frames, newest_frame), axis=2)
def play_atari_game(
    env: gym.Env,
    model: Any,
    img_transform: Any,
    n_frames: int = 4,
) -> float:
    """Play one episode greedily with ``model``, print and return the total reward.

    Each raw observation is preprocessed with ``img_transform.transform``.
    The most recent ``n_frames`` preprocessed frames are stacked along axis 2
    to form the model input. At every step, the action with the highest model
    output is taken (pure argmax, no exploration).

    Args:
        env: Gymnasium environment; ``step`` is expected to return the
            5-tuple ``(obs, reward, terminated, truncated, info)``.
        model: Object exposing ``predict(batch)``. The batch has a leading
            axis of size 1. The output may be a NumPy array or a tensor
            exposing ``.numpy()``; presumably shaped ``(1, n_actions)`` —
            confirm against the model definition.
        img_transform: Object exposing ``transform(obs)``. Presumably returns
            a single 2-D frame, since frames are stacked along a new axis 2.
        n_frames: Number of consecutive frames stacked into the state.
            Defaults to 4, the previously hard-coded value.

    Returns:
        The undiscounted sum of rewards collected during the episode.

    Raises:
        ValueError: If ``n_frames`` is smaller than 1.
    """
    if n_frames < 1:
        raise ValueError(f"n_frames must be >= 1, got {n_frames}")

    obs, _ = env.reset()
    obs_small = img_transform.transform(obs)
    # At reset there is no history yet, so the first frame fills every slot.
    state = np.stack([obs_small] * n_frames, axis=2)

    terminated, truncated = False, False
    episode_reward = 0
    while not (terminated or truncated):
        q_values = model.predict(np.expand_dims(state, axis=0))
        # Accept both plain ndarrays (e.g. Keras ``Model.predict``, which has
        # no ``.numpy``) and framework tensors that must be converted first.
        if hasattr(q_values, "numpy"):
            q_values = q_values.numpy()
        action = int(np.argmax(q_values, axis=1)[0])

        obs, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        # Drop the oldest frame and append the newest preprocessed one.
        state = update_state(state=state, obs_small=img_transform.transform(obs))

    print(f"Total reward earned: {episode_reward}")
    return episode_reward