# File size: 1,171 bytes | revision c121225
"""
@author: bvk1ng (Adityam Ghosh)
Date: 12/28/2023
"""
from typing import Callable, List, Tuple, Any, Dict, Union
import numpy as np
import gymnasium as gym
def update_state(state: np.ndarray, obs_small: np.ndarray) -> np.ndarray:
    """Shift a frame stack by one frame in FIFO order.

    The oldest frame (index 0 along axis 2) is dropped and ``obs_small`` is
    appended as the newest frame at the end of axis 2. A new array is
    returned; ``state`` itself is not modified.

    Args:
        state: Frame stack whose frames lie along axis 2.
        obs_small: The newest frame; it gains a new axis 2 before being
            joined, so its shape must match ``state`` without axis 2.

    Returns:
        The updated frame stack, with the same size along axis 2 as ``state``.
    """
    # Everything except the oldest frame survives.
    surviving_frames = state[:, :, 1:]
    # Give the new frame a length-1 axis 2 so it can be joined to the stack.
    newest_frame = np.expand_dims(obs_small, axis=2)
    return np.concatenate((surviving_frames, newest_frame), axis=2)
def play_atari_game(env: gym.Env, model: Any, img_transform: Any) -> float:
    """Play one episode greedily with ``model`` and report the total reward.

    The environment is reset and the first preprocessed frame is repeated four
    times to build the initial frame stack. At every step, the action whose
    model output is largest is taken. The loop runs until the environment
    reports that the episode terminated or was truncated.

    Args:
        env: Gymnasium environment. ``reset()`` returns ``(obs, info)`` and
            ``step()`` returns ``(obs, reward, terminated, truncated, info)``.
        model: Object exposing ``predict(batch)``. Its output is indexed as
            ``(batch, n_actions)``. Either a NumPy array (e.g. Keras
            ``Model.predict``) or a tensor with a ``.numpy()`` method is
            accepted.
        img_transform: Object exposing ``transform(obs)``, which returns the
            preprocessed frame used to build the stacked state.

    Returns:
        The summed reward for the episode. It is also printed.
    """
    obs, _ = env.reset()
    # Fill the stack with copies of the first frame so it is complete from step 0.
    state = np.stack([img_transform.transform(obs)] * 4, axis=2)
    terminated, truncated = False, False
    episode_reward = 0
    while not (terminated or truncated):
        # Add a batch axis: the model is queried with a single state.
        q_values = model.predict(np.expand_dims(state, axis=0))
        # Tensor outputs expose .numpy(); an ndarray (e.g. from Keras
        # Model.predict) does not, and calling it unconditionally would raise.
        if hasattr(q_values, "numpy"):
            q_values = q_values.numpy()
        # Greedy action for the only entry in the batch.
        action = int(np.argmax(q_values, axis=1)[0])
        obs, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward
        state = update_state(state=state, obs_small=img_transform.transform(obs))
    print(f"Total reward earned: {episode_reward}")
    return episode_reward
|