File size: 4,054 Bytes
c121225 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
"""
Main script to run the Atari Breakout-v0 game.
The DQN algorithm was used to train the agent.
@author: bvk1ng (Adityam Ghosh)
Date: 12/28/2023
"""
from typing import List, Dict, Any, Callable, Tuple, Union
import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import albumentations as A
import cv2
import os
import argparse
from model import CNNModel
from utils import play_atari_game, gym
from gymnasium.wrappers.record_video import RecordVideo
# Value forwarded as ``K`` to the DQN / CNNModel constructors.
# NOTE(review): presumably the number of consecutive frames stacked into one
# state -- confirm against model.py.
K = 4
# Height and width (in pixels) that ImageTransform resizes each cropped frame to.
IM_SIZE = 84
class ImageTransform:
    """Turn an RGB game frame into a cropped, square, grayscale image.

    The pipeline is built once in the constructor and re-used for every frame:
    a fixed crop of the region x in [0, 160), y in [34, 200), followed by a
    nearest-neighbour resize to ``IM_SIZE`` x ``IM_SIZE``.
    """

    def __init__(self):
        # Fixed crop window applied to every frame.
        # NOTE(review): presumably trims the score bar / borders of a
        # 210x160 Atari frame -- confirm against the environment.
        crop_step = A.Crop(
            x_min=0,
            y_min=34,
            x_max=160,
            y_max=200,
            always_apply=True,
        )
        # Nearest-neighbour interpolation, so no new intensity values are
        # blended in while downscaling.
        resize_step = A.Resize(
            height=IM_SIZE,
            width=IM_SIZE,
            interpolation=cv2.INTER_NEAREST,
            always_apply=True,
        )
        self.compose = A.Compose([crop_step, resize_step])

    def transform(self, img: np.ndarray) -> np.ndarray:
        """Convert ``img`` from RGB to grayscale, then crop and resize it.

        Args:
            img: frame in RGB channel order (as expected by
                ``cv2.COLOR_RGB2GRAY``).

        Returns:
            The ``"image"`` entry produced by the albumentations pipeline.
        """
        grayscale = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        return self.compose(image=grayscale)["image"]
class DQN:
    """Inference wrapper around ``CNNModel`` used to play the game.

    Builds the network on the requested device, optionally restores weights
    from a checkpoint, and exposes ``predict`` for batched forward passes.
    """

    def __init__(
        self,
        K: int,
        cnn_params: List,
        fully_connected_params: List,
        device: str = "cuda",
        load_path: Union[str, None] = None,
    ):
        """Create the network and optionally load saved weights.

        Args:
            K: forwarded to ``CNNModel`` as ``K``.
            cnn_params: forwarded to ``CNNModel`` as ``cnn_params``.
            fully_connected_params: forwarded to ``CNNModel`` as
                ``fully_connected_params``.
            device: torch device string the model and inputs are moved to.
            load_path: checkpoint holding a ``state_dict``; ``None`` keeps
                the model's initial weights.
        """
        self.K = K
        # Set before ``load`` so the checkpoint can be mapped onto this device.
        self.device = device
        self.cnn_model = CNNModel(
            K=K,
            cnn_params=cnn_params,
            fully_connected_params=fully_connected_params,
        ).to(device=device)
        self.load(load_path)

    def predict(self, states: np.ndarray) -> torch.Tensor:
        """Run the network on a batch of states and return its output on CPU.

        Args:
            states: 4-D array whose last axis is moved to position 1 before
                the forward pass (channels-last -> channels-first). Values
                are divided by 255 before being fed to the network.

        Returns:
            The model output as a CPU tensor that does not require grad.
            NOTE(review): presumably one Q-value per action -- confirm
            against CNNModel.
        """
        channels_first = np.transpose(states, (0, 3, 1, 2))  # (N, T, H, W)
        # Inference only: skip building the autograd graph instead of
        # detaching it afterwards.
        with torch.no_grad():
            batch = torch.from_numpy(channels_first).float().to(device=self.device)
            # Out-of-place division: for a float32 input on CPU the tensor
            # shares memory with the caller's array, so ``/=`` would silently
            # rescale the caller's data.
            batch = batch / 255.0
            return self.cnn_model(batch).cpu()

    def load(self, path: Union[str, None]):
        """Load a saved ``state_dict`` into the network; no-op when ``path`` is None.

        Args:
            path: checkpoint location, or ``None`` to skip loading.
        """
        if path is None:
            return
        # map_location lets a checkpoint saved on one device (e.g. CUDA) be
        # restored on another (e.g. CPU).
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(path, map_location=self.device)
        self.cnn_model.load_state_dict(state_dict)
if __name__ == "__main__":
    # Entry point: parse CLI options, build the agent (restoring a checkpoint
    # when one exists), create the Breakout environment and play one session.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_folder",
        "-mF",
        type=str,
        required=False,
        default="./models",
        help="the folder to store the models.",
    )
    parser.add_argument(
        "--model_name",
        "-mf",
        type=str,
        required=False,
        default="atari_breakout_v0.pt",
        help="the name of the model to save.",
    )
    parser.add_argument(
        "--save_video",
        "-s",
        type=int,
        required=False,
        default=0,
        help="whether to save a video of the gameplay or not.",
    )
    parser.add_argument(
        "--video_folder",
        "-V",
        type=str,
        required=False,
        default="./videos",
        help="where to save the video.",
    )
    parser.add_argument(
        "--video_name",
        "-v",
        type=str,
        required=False,
        default="atari_breakout_v0",
        help="the name of the video file.",
    )
    args = parser.parse_args()

    model_folder = args.model_folder
    model_name = args.model_name
    save_video = args.save_video
    video_folder = args.video_folder
    video_name = args.video_name

    # NOTE(review): each tuple is presumably (out_channels, kernel_size,
    # stride) for one conv layer, and 512 the hidden layer width -- confirm
    # against CNNModel.
    cnn_params = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
    fully_connected_params = [512]

    # Only hand a checkpoint to the model when the file is actually there.
    checkpoint = os.path.join(model_folder, model_name)
    load_path = checkpoint if os.path.exists(checkpoint) else None
    if load_path is None:
        print(
            f"warning: no checkpoint found at {checkpoint!r}; "
            "using the model's initial weights."
        )

    # Fall back to CPU so the script still runs on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # DQN.__init__ takes no learning-rate argument (this script only plays),
    # so none is passed here.
    model = DQN(
        K=K,
        cnn_params=cnn_params,
        fully_connected_params=fully_connected_params,
        device=device,
        load_path=load_path,
    )
    img_transformer = ImageTransform()

    if save_video:
        # Frames must be rendered as arrays for the video recorder.
        env = gym.make("Breakout-v0", render_mode="rgb_array")
        env = RecordVideo(env=env, video_folder=video_folder, name_prefix=video_name)
        env.reset()
        env.start_video_recorder()
    else:
        env = gym.make("Breakout-v0", render_mode="human")

    try:
        play_atari_game(env=env, model=model, img_transform=img_transformer)
    finally:
        # Closing the env releases the render window and lets the recording
        # wrapper finalize the video file, even if the game loop raises.
        env.close()
|