bvk1ng committed on
Commit
c121225
1 Parent(s): f988250

Stage-1 commit: Agent trained for 3500 episodes

Browse files
README.md CHANGED
@@ -1,8 +1,13 @@
1
  ---
2
  license: mit
3
  language:
4
- - en
5
  tags:
6
- - reinforcement learning
7
- - games
8
- ---
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  language:
4
+ - en
5
  tags:
6
+ - reinforcement learning
7
+ - games
8
+ ---
9
+
10
+ # Deep Q-Learning based Agent for Atari Breakout
11
+
12
+ The agent showcased in this space was trained using the Deep Q-Learning (DQN) algorithm.
13
+ The agent was trained for $3500$ episodes with a learning rate of $0.00001$ and an epsilon value that decreased linearly over time.
atari_breakout_v0-episode-0.mp4 ADDED
Binary file (79.2 kB). View file
 
main.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main script to run the Atari Breakout-v0 game.
3
+ The DQN algorithm was used to train the agent.
4
+
5
+ @author: bvk1ng (Adityam Ghosh)
6
+ Date: 12/28/2023
7
+ """
8
+
9
+ from typing import List, Dict, Any, Callable, Tuple, Union
10
+
11
+ import numpy as np
12
+ import gymnasium as gym
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import albumentations as A
17
+ import cv2
18
+ import os
19
+ import argparse
20
+
21
+
22
+ from model import CNNModel
23
+ from utils import play_atari_game, gym
24
+ from gymnasium.wrappers.record_video import RecordVideo
25
+
26
+
27
+ K = 4
28
+ IM_SIZE = 84
29
+
30
+
31
class ImageTransform:
    """Convert a raw RGB game frame into a small grayscale image.

    The frame is converted to grayscale, cropped to a fixed window and then
    resized to an ``IM_SIZE`` x ``IM_SIZE`` square.
    """

    def __init__(self):
        # Keep columns 0..160 and rows 34..200 of the incoming frame.
        crop_step = A.Crop(x_min=0, y_min=34, x_max=160, y_max=200, always_apply=True)

        # Shrink the cropped window with nearest-neighbour interpolation.
        resize_step = A.Resize(
            height=IM_SIZE,
            width=IM_SIZE,
            interpolation=cv2.INTER_NEAREST,
            always_apply=True,
        )

        self.compose = A.Compose([crop_step, resize_step])

    def transform(self, img: np.ndarray) -> np.ndarray:
        """Return the grayscale, cropped and resized version of ``img``.

        Args:
            img: frame passed to ``cv2.cvtColor`` with ``COLOR_RGB2GRAY``
                (presumably an H x W x 3 RGB array -- confirm with caller).

        Returns:
            The ``"image"`` entry produced by the crop/resize pipeline.
        """
        grayscale = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        return self.compose(image=grayscale)["image"]
+
50
+
51
class DQN:
    """Inference-only wrapper around ``CNNModel``.

    Builds the network, moves it to ``device`` and optionally restores
    weights from a checkpoint file.

    Args:
        K: forwarded unchanged to ``CNNModel`` as its ``K`` argument.
        cnn_params: forwarded to ``CNNModel`` as ``cnn_params``.
        fully_connected_params: forwarded to ``CNNModel`` as
            ``fully_connected_params``.
        device: torch device on which the network and its inputs are placed.
        load_path: checkpoint holding a ``state_dict``; ``None`` skips loading.
    """

    def __init__(
        self,
        K: int,
        cnn_params: List,
        fully_connected_params: List,
        device: str = "cuda",
        load_path: Union[str, None] = None,
    ):
        self.K = K
        self.cnn_model = CNNModel(
            K=K,
            cnn_params=cnn_params,
            fully_connected_params=fully_connected_params,
        ).to(device=device)
        # Must be set before ``load`` so the checkpoint is mapped to it.
        self.device = device

        self.load(load_path)

    def predict(self, states: np.ndarray) -> torch.Tensor:
        """Run the network on a batch of states and return its output on CPU.

        Args:
            states: 4-D array whose last axis is moved to position 1 before
                the forward pass (i.e. channels-last in, channels-first to
                the network). Values are divided by 255 after the cast to
                float.

        Returns:
            The network output as a CPU tensor that carries no gradient.
        """
        states = np.transpose(states, (0, 3, 1, 2))  # (N, T, H, W)

        # Inference only: skip autograd bookkeeping instead of building a
        # graph and detaching it afterwards.
        with torch.no_grad():
            batch = torch.from_numpy(states).float().to(device=self.device)
            batch /= 255.0
            return self.cnn_model(batch).cpu()

    def load(self, path: Union[str, None]):
        """Restore network weights from ``path``; do nothing when it is ``None``.

        The checkpoint is mapped onto ``self.device`` so that weights saved
        on a GPU can still be loaded on a CPU-only machine.
        """
        if path is None:
            return

        # NOTE: ``torch.load`` unpickles the file -- only load checkpoints
        # from a trusted source.
        weights = torch.load(path, map_location=self.device)
        self.cnn_model.load_state_dict(weights)
+
82
+
83
if __name__ == "__main__":
    # Command-line entry point: load the trained agent (if a checkpoint is
    # present) and let it play one episode, optionally recording a video.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_folder",
        "-mF",
        type=str,
        required=False,
        default="./models",
        help="the folder containing the saved model.",
    )
    parser.add_argument(
        "--model_name",
        "-mf",
        type=str,
        required=False,
        default="atari_breakout_v0.pt",
        help="the file name of the saved model to load.",
    )

    parser.add_argument(
        "--save_video",
        "-s",
        type=int,
        required=False,
        default=0,
        help="whether to save a video of the gameplay or not.",
    )

    parser.add_argument(
        "--video_folder",
        "-V",
        type=str,
        required=False,
        default="./videos",
        help="where to save the video.",
    )

    parser.add_argument(
        "--video_name",
        "-v",
        type=str,
        required=False,
        default="atari_breakout_v0",
        help="the name of the video file.",
    )

    args = parser.parse_args()

    model_folder = args.model_folder
    model_name = args.model_name
    save_video = args.save_video
    video_folder = args.video_folder
    video_name = args.video_name

    # Each tuple is unpacked by CNNModel as (out_channels, kernel_size, stride).
    cnn_params = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]
    fully_connected_params = [512]

    model_path = os.path.join(model_folder, model_name)
    load_path = model_path if os.path.exists(model_path) else None
    if load_path is None:
        # Keep the best-effort behaviour, but make it visible to the user.
        print(
            f"Warning: no checkpoint found at {model_path}; "
            "playing with untrained weights."
        )

    # Fall back to the CPU when no GPU is available instead of crashing.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # DQN.__init__ takes no ``lr`` argument; passing one raised a TypeError.
    model = DQN(
        K=K,
        cnn_params=cnn_params,
        fully_connected_params=fully_connected_params,
        device=device,
        load_path=load_path,
    )

    img_transformer = ImageTransform()

    if save_video:
        env = gym.make("Breakout-v0", render_mode="rgb_array")
        env = RecordVideo(env=env, video_folder=video_folder, name_prefix=video_name)

        env.reset()
        env.start_video_recorder()

    else:
        env = gym.make("Breakout-v0", render_mode="human")

    try:
        play_atari_game(env=env, model=model, img_transform=img_transformer)
    finally:
        # Release the emulator/window and let the recorder finish its file.
        env.close()
model.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @author: bvk1ng (Adityam Ghosh)
3
+ Date: 12/28/2023
4
+
5
+ """
6
+ from typing import Any, List, Tuple, Dict, Union, Callable
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
+
13
class CNNModel(nn.Module):
    """Convolutional network assembled from simple layer specifications.

    The layers are registered, in order, as: one ``conv2d_i`` + ``activation_i``
    pair per entry of ``cnn_params``, a ``flatten`` layer, one ``fc_i`` +
    ``fc_activation_i`` pair per entry of ``fully_connected_params`` and a
    ``final_layer``. Lazy layers are used, so input sizes are inferred on the
    first forward pass.

    Args:
        K: number of output features of the final linear layer.
        cnn_params: iterable of ``(out_channels, kernel_size, stride)`` triples.
        fully_connected_params: output widths of the hidden linear layers.
    """

    def __init__(self, K: int, cnn_params: List, fully_connected_params: List):
        super().__init__()

        layers = nn.Sequential()

        # Convolutional stack: every convolution is followed by a ReLU.
        for layer_no, conv_spec in enumerate(cnn_params):
            n_filters, ksize, step = conv_spec
            conv = nn.LazyConv2d(
                out_channels=n_filters,
                kernel_size=ksize,
                stride=step,
            )
            layers.add_module(f"conv2d_{layer_no}", conv)
            layers.add_module(f"activation_{layer_no}", nn.ReLU())

        layers.add_module("flatten", nn.Flatten())

        # Hidden fully connected stack, again with ReLU activations.
        for layer_no, width in enumerate(fully_connected_params):
            layers.add_module(f"fc_{layer_no}", nn.LazyLinear(out_features=width))
            layers.add_module(f"fc_activation_{layer_no}", nn.ReLU())

        # Output layer, left without an activation.
        layers.add_module("final_layer", nn.LazyLinear(out_features=K))

        self.network = layers

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Pass ``X`` through every registered layer and return the result."""
        return self.network(X)
atari_breakout_v0.pt → models/atari_breakout_v0.pt RENAMED
File without changes
utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @author: bvk1ng (Adityam Ghosh)
3
+ Date: 12/28/2023
4
+ """
5
+
6
+ from typing import Callable, List, Tuple, Any, Dict, Union
7
+
8
+ import numpy as np
9
+ import gymnasium as gym
10
+
11
+
12
def update_state(state: np.ndarray, obs_small: np.ndarray) -> np.ndarray:
    """Slide the frame window: drop the oldest frame and add the newest.

    The first slice along axis 2 of ``state`` is discarded and ``obs_small``
    is attached as the new last slice, so the depth of the result equals the
    depth of ``state``.
    """
    remaining_frames = state[:, :, 1:]
    newest_frame = np.expand_dims(obs_small, axis=2)
    return np.concatenate((remaining_frames, newest_frame), axis=2)
15
+
16
+
17
def play_atari_game(env: gym.Env, model: Callable, img_transform: Callable):
    """Play a single episode greedily and print the total reward.

    Args:
        env: environment that is reset once and stepped until it reports
            termination or truncation.
        model: object whose ``predict`` returns a tensor with one row of
            scores per state; the index of the largest score is the action.
        img_transform: object whose ``transform`` preprocesses each
            observation before it enters the state.
    """
    first_obs, _ = env.reset()
    frame = img_transform.transform(first_obs)

    # The initial state is the first frame repeated four times along axis 2.
    state = np.stack([frame, frame, frame, frame], axis=2)

    episode_reward = 0
    terminated = truncated = False

    while not terminated and not truncated:
        # Add a leading batch axis, score the state, act greedily.
        scores = model.predict(state[np.newaxis, ...]).numpy()
        chosen_action = np.argmax(scores, axis=1)[0]

        raw_obs, reward, terminated, truncated, _ = env.step(chosen_action)
        frame = img_transform.transform(raw_obs)
        episode_reward += reward

        # Roll the window forward so the newest frame replaces the oldest.
        state = update_state(state=state, obs_small=frame)

    print(f"Total reward earned: {episode_reward}")