bvk1ng's picture
Stage-1 commit: Agent trained for 3500 episodes
c121225
"""
@author: bvk1ng (Adityam Ghosh)
Date: 12/28/2023
"""
from typing import Any, List, Tuple, Dict, Union, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNNModel(nn.Module):
    """Stack of conv blocks followed by fully connected blocks and a linear head.

    The whole pipeline lives in ``self.network`` (an ``nn.Sequential``) with
    these submodule names, in order:

    * ``conv2d_{i}`` / ``activation_{i}`` for every entry of ``cnn_params``
    * ``flatten``
    * ``fc_{i}`` / ``fc_activation_{i}`` for every entry of ``fully_connected_params``
    * ``final_layer`` producing ``K`` outputs (no activation applied)

    All conv and linear layers are *lazy*: their input channel/feature sizes
    are inferred on the first forward pass.

    NOTE: because the parameters are uninitialized until that first call, run
    one dummy forward pass before building an optimizer over ``parameters()``.
    """

    def __init__(self, K: int, cnn_params: List, fully_connected_params: List):
        """Assemble the network from layer specifications.

        Args:
            K: number of output units of the final linear layer (presumably
                the number of actions/classes — confirm against callers).
            cnn_params: sequence of ``(out_channels, kernel_size, stride)``
                triples, one per convolutional layer.
            fully_connected_params: sequence of hidden-layer widths, one per
                fully connected layer placed between ``flatten`` and the head.
        """
        super().__init__()

        # Collect (name, module) pairs first, then register them in one pass.
        # The names are kept stable because they become the state_dict keys.
        named_layers: List[Tuple[str, nn.Module]] = []

        for i, spec in enumerate(cnn_params):
            channels, kernel, step = spec
            named_layers.append(
                (
                    f"conv2d_{i}",
                    nn.LazyConv2d(out_channels=channels, kernel_size=kernel, stride=step),
                )
            )
            named_layers.append((f"activation_{i}", nn.ReLU()))

        # nn.Flatten keeps dim 0, so inputs are assumed to be batched (N, C, H, W).
        named_layers.append(("flatten", nn.Flatten()))

        for i, width in enumerate(fully_connected_params):
            named_layers.append((f"fc_{i}", nn.LazyLinear(out_features=width)))
            named_layers.append((f"fc_activation_{i}", nn.ReLU()))

        # Output head: raw (un-activated) scores of size K.
        named_layers.append(("final_layer", nn.LazyLinear(out_features=K)))

        self.network = nn.Sequential()
        for name, module in named_layers:
            self.network.add_module(name, module)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run ``X`` through the network and return the ``K`` raw outputs per sample."""
        output = self.network(X)
        return output