"""
Main script to run the Atari Breakout-v0 game. The DQN algorithm was used to
train the agent.

@author: bvk1ng (Adityam Ghosh)
Date: 12/28/2023
"""

from typing import List, Dict, Any, Callable, Tuple, Union, Optional

import numpy as np
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import albumentations as A
import cv2
import os
import argparse

from model import CNNModel
from utils import play_atari_game, gym
from gymnasium.wrappers.record_video import RecordVideo

# Value passed as ``K`` to the network; presumably the number of stacked
# frames per state -- confirm against ``CNNModel``.
K = 4
# Side length (pixels) of the square frame produced by ``ImageTransform``.
IM_SIZE = 84


class ImageTransform:
    """Pre-process a raw RGB frame into a cropped, resized grayscale image."""

    def __init__(self):
        # Crop rows 34..200 and columns 0..160, then resize to
        # IM_SIZE x IM_SIZE using nearest-neighbour interpolation.
        self.compose = A.Compose(
            [
                A.Crop(x_min=0, y_min=34, x_max=160, y_max=200, always_apply=True),
                A.Resize(
                    height=IM_SIZE,
                    width=IM_SIZE,
                    interpolation=cv2.INTER_NEAREST,
                    always_apply=True,
                ),
            ]
        )

    def transform(self, img: np.ndarray) -> np.ndarray:
        """Convert ``img`` from RGB to grayscale and apply the crop/resize.

        Args:
            img: RGB frame as accepted by ``cv2.cvtColor``.

        Returns:
            The transformed single-channel image.
        """
        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img_tf = self.compose(image=gray_img)
        return img_tf["image"]


class DQN:
    """Inference-only wrapper around ``CNNModel``."""

    def __init__(
        self,
        K: int,
        cnn_params: List,
        fully_connected_params: List,
        device: str = "cuda",
        load_path: Optional[str] = None,
    ):
        """Build the network on ``device`` and optionally load saved weights.

        Args:
            K: forwarded to ``CNNModel`` as ``K``.
            cnn_params: forwarded to ``CNNModel`` as ``cnn_params``.
            fully_connected_params: forwarded to ``CNNModel``.
            device: torch device on which the model and inputs are placed.
            load_path: path of a saved ``state_dict``; ``None`` skips loading.
        """
        self.K = K
        self.cnn_model = CNNModel(
            K=K,
            cnn_params=cnn_params,
            fully_connected_params=fully_connected_params,
        ).to(device=device)
        self.device = device

        self.load(load_path)

    def predict(self, states: np.ndarray) -> torch.Tensor:
        """Run the network on a batch of states.

        Args:
            states: 4-D array whose last axis is moved to axis 1 before the
                forward pass (channels-last -> channels-first). Values are
                divided by 255, so presumably raw pixel intensities.

        Returns:
            The model output as a CPU tensor that carries no gradient.
        """
        states = np.transpose(states, (0, 3, 1, 2))  # -> (N, T, H, W)
        batch = torch.from_numpy(states).float().to(device=self.device)
        # Out-of-place division: with float32 input on a CPU device the tensor
        # shares memory with the caller's array, and ``/=`` would modify it.
        batch = batch / 255.0

        # No autograd graph is needed for inference.
        with torch.no_grad():
            return self.cnn_model(batch).cpu()

    def load(self, path: Optional[str]):
        """Load a ``state_dict`` from ``path`` into the model, if given.

        Args:
            path: checkpoint path, or ``None`` to leave the weights untouched.
        """
        if path is not None:
            # map_location lets a checkpoint saved on GPU load on a CPU-only
            # machine (and vice versa).
            # NOTE: torch.load unpickles the file -- only load trusted
            # checkpoints.
            self.cnn_model.load_state_dict(
                torch.load(path, map_location=self.device)
            )


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the play script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_folder",
        "-mF",
        type=str,
        required=False,
        default="./models",
        help="the folder to store the models.",
    )
    parser.add_argument(
        "--model_name",
        "-mf",
        type=str,
        required=False,
        default="atari_breakout_v0.pt",
        help="the name of the model to save.",
    )
    parser.add_argument(
        "--save_video",
        "-s",
        type=int,
        required=False,
        default=0,
        help="whether to save a video of the gameplay or not.",
    )
    parser.add_argument(
        "--video_folder",
        "-V",
        type=str,
        required=False,
        default="./videos",
        help="where to save the video.",
    )
    parser.add_argument(
        "--video_name",
        "-v",
        type=str,
        required=False,
        default="atari_breakout_v0",
        help="the name of the video file.",
    )
    return parser.parse_args()


def main() -> None:
    """Build the agent and the environment, then play one game."""
    args = parse_args()

    model_folder = args.model_folder
    model_name = args.model_name
    save_video = args.save_video
    video_folder = args.video_folder
    video_name = args.video_name

    cnn_params = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
    fully_connected_params = [512]

    # Only load weights when the checkpoint file actually exists.
    checkpoint = os.path.join(model_folder, model_name)
    load_path = checkpoint if os.path.exists(checkpoint) else None

    # Use the GPU when present, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # DQN.__init__ takes no learning rate; passing ``lr`` raised a TypeError.
    model = DQN(
        K=K,
        cnn_params=cnn_params,
        fully_connected_params=fully_connected_params,
        device=device,
        load_path=load_path,
    )

    img_transformer = ImageTransform()

    if save_video:
        env = gym.make("Breakout-v0", render_mode="rgb_array")
        env = RecordVideo(env=env, video_folder=video_folder, name_prefix=video_name)
        env.reset()
        env.start_video_recorder()
    else:
        env = gym.make("Breakout-v0", render_mode="human")

    try:
        play_atari_game(env=env, model=model, img_transform=img_transformer)
    finally:
        # Closing the env releases the renderer and finalises any recording;
        # closing an already-closed env is a no-op.
        env.close()


if __name__ == "__main__":
    main()