""" @author: bvk1ng (Adityam Ghosh) Date: 12/28/2023 """ from typing import Callable, List, Tuple, Any, Dict, Union import numpy as np import gymnasium as gym def update_state(state: np.ndarray, obs_small: np.ndarray) -> np.ndarray: """Function to append the recent state into the state variable and remove the oldest using FIFO.""" return np.append(state[:, :, 1:], np.expand_dims(obs_small, axis=2), axis=2) def play_atari_game(env: gym.Env, model: Callable, img_transform: Callable): """Function to play the atari game.""" obs, info = env.reset() obs_small = img_transform.transform(obs) state = np.stack([obs_small] * 4, axis=2) done, truncated = False, False episode_reward = 0 while not (done or truncated): action = model.predict(np.expand_dims(state, axis=0)).numpy() action = np.argmax(action, axis=1)[0] obs, reward, done, truncated, info = env.step(action) obs_small = img_transform.transform(obs) episode_reward += reward next_state = update_state(state=state, obs_small=obs_small) state = next_state print(f"Total reward earned: {episode_reward}")