File size: 1,171 Bytes
c121225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""
@author: bvk1ng (Adityam Ghosh)
Date: 12/28/2023
"""

from typing import Callable, List, Tuple, Any, Dict, Union

import numpy as np
import gymnasium as gym


def update_state(state: np.ndarray, obs_small: np.ndarray) -> np.ndarray:
    """Slide the frame window by one: drop the oldest frame, append the newest.

    Frames are stacked along axis 2 of ``state``. The first slice on that axis
    is discarded and ``obs_small`` (given a new axis at position 2) is
    concatenated at the end, so the number of stacked frames stays the same.
    A new array is returned; ``state`` itself is not modified.
    """
    newest_frame = np.expand_dims(obs_small, axis=2)
    kept_frames = state[:, :, 1:]
    return np.concatenate((kept_frames, newest_frame), axis=2)


def play_atari_game(
    env: gym.Env,
    model: Any,
    img_transform: Any,
    *,
    n_frames: int = 4,
    verbose: bool = True,
) -> float:
    """Play one episode greedily with ``model`` and return the total reward.

    The environment is reset, the first preprocessed frame is repeated
    ``n_frames`` times to seed the frame stack, and then the action with the
    highest model score is taken at every step until the episode is
    terminated or truncated.

    Args:
        env: A gymnasium environment whose ``step`` returns the 5-tuple
            ``(obs, reward, terminated, truncated, info)``.
        model: Object exposing ``predict(batch)``. The code takes
            ``argmax(..., axis=1)[0]`` of the output, so it is presumably a
            ``(batch, n_actions)`` score array (TODO confirm). Either a NumPy
            array or a tensor exposing ``.numpy()`` is accepted.
        img_transform: Object exposing ``transform(obs)``; its results are
            stacked along axis 2 to form the state.
        n_frames: Number of frames kept in the stacked state (keyword-only,
            default 4 -- the previously hard-coded value).
        verbose: If True (default), print the total reward as before.

    Returns:
        The undiscounted sum of rewards collected during the episode.

    Raises:
        ValueError: If ``n_frames`` is smaller than 1.
    """
    if n_frames < 1:
        raise ValueError(f"n_frames must be >= 1, got {n_frames}")

    obs, _info = env.reset()
    obs_small = img_transform.transform(obs)
    # Seed the stack by repeating the first frame so the state has full depth from step 0.
    state = np.stack([obs_small] * n_frames, axis=2)

    terminated, truncated = False, False
    episode_reward = 0

    while not (terminated or truncated):
        scores = model.predict(np.expand_dims(state, axis=0))
        # Keras' Model.predict already returns an ndarray (which has no .numpy());
        # only tensor-like outputs need converting.
        if hasattr(scores, "numpy"):
            scores = scores.numpy()
        action = np.argmax(scores, axis=1)[0]
        obs, reward, terminated, truncated, _info = env.step(action)

        episode_reward += reward
        state = update_state(state=state, obs_small=img_transform.transform(obs))

    if verbose:
        print(f"Total reward earned: {episode_reward}")
    return episode_reward