|
""" |
|
@author: bvk1ng (Adityam Ghosh) |
|
Date: 12/28/2023 |
|
""" |
|
|
|
from typing import Callable, List, Tuple, Any, Dict, Union |
|
|
|
import numpy as np |
|
import gymnasium as gym |
|
|
|
|
|
def update_state(state: np.ndarray, obs_small: np.ndarray) -> np.ndarray:
    """Push the newest frame onto the frame stack and drop the oldest one (FIFO).

    Frames are stacked along axis 2, with the oldest frame at index 0.
    ``obs_small`` becomes the last entry along that axis. ``state`` itself is
    not modified; a newly allocated array is returned whose axis-2 length
    equals that of ``state`` (one frame removed, one frame added).
    """
    # Give the new frame a length-1 axis 2 so it can be joined onto the stack.
    newest = np.expand_dims(obs_small, axis=2)
    # Everything except the oldest frame survives the shift.
    kept_frames = state[:, :, 1:]
    return np.concatenate((kept_frames, newest), axis=2)
|
|
|
|
|
def play_atari_game(env: gym.Env, model: Any, img_transform: Any) -> float:
    """Play one episode of an Atari game greedily with ``model`` and report the score.

    The agent observes a stack of the 4 most recent preprocessed frames. At
    every step it takes the action with the highest score returned by
    ``model.predict``. The episode ends when the environment reports either
    ``terminated`` or ``truncated``.

    Parameters
    ----------
    env : gym.Env
        Gymnasium environment using the 5-tuple ``step`` API
        ``(obs, reward, terminated, truncated, info)``.
    model : Any
        Object exposing ``predict(batch)``. It receives the frame stack with a
        leading batch axis of size 1 and must return per-action scores whose
        first row (axis 0, index 0) is argmax-ed over axis 1. NumPy arrays and
        tensor-like objects with a ``.numpy()`` method are both accepted.
    img_transform : Any
        Object exposing ``transform(obs)`` that preprocesses a raw observation.
        Presumably this returns a 2-D (H, W) frame, since frames are stacked
        along axis 2; verify against the transform in use.

    Returns
    -------
    float
        Total undiscounted reward collected during the episode. The value is
        also printed.
    """
    obs, _ = env.reset()
    frame = img_transform.transform(obs)
    # Seed the stack with the first frame repeated 4 times so it is full from step 0.
    state = np.stack([frame] * 4, axis=2)

    terminated, truncated = False, False
    episode_reward = 0

    while not (terminated or truncated):
        q_values = model.predict(np.expand_dims(state, axis=0))
        # Keras' Model.predict already returns NumPy arrays, which have no
        # ``.numpy()`` method; only call it on tensor-like results.
        if hasattr(q_values, "numpy"):
            q_values = q_values.numpy()
        else:
            q_values = np.asarray(q_values)
        action = int(np.argmax(q_values, axis=1)[0])

        obs, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward

        # Slide the frame window: drop the oldest frame, append the newest one.
        state = update_state(state=state, obs_small=img_transform.transform(obs))

    print(f"Total reward earned: {episode_reward}")
    return episode_reward
|
|