"""
@author: bvk1ng (Adityam Ghosh)
Date: 12/28/2023
"""

from typing import Any, List, Tuple, Dict, Union, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


class CNNModel(nn.Module):
    """Sequential network: convolutional stack, flatten, then linear layers.

    Every convolutional and linear layer is a lazy module, so input
    channel / feature counts are not given here; PyTorch infers them on the
    first forward pass.

    Args:
        K: Number of output features produced by the final linear layer.
        cnn_params: One ``(out_channels, kernel_size, stride)`` triple per
            convolutional layer, in the order the layers are applied.
        fully_connected_params: Output feature count of each hidden linear
            layer, in the order the layers are applied.
    """

    def __init__(self, K: int, cnn_params: List, fully_connected_params: List):
        super().__init__()

        # Collect (name, module) pairs first, then register them in one pass.
        # The order of this list is the execution order of the network.
        named_layers: List[Tuple[str, nn.Module]] = []

        # Convolutional stack: each convolution is followed by a ReLU.
        for layer_no, (n_filters, k_size, step) in enumerate(cnn_params):
            conv = nn.LazyConv2d(
                out_channels=n_filters,
                kernel_size=k_size,
                stride=step,
            )
            named_layers.append((f"conv2d_{layer_no}", conv))
            named_layers.append((f"activation_{layer_no}", nn.ReLU()))

        # Flatten the convolutional output before the linear layers.
        named_layers.append(("flatten", nn.Flatten()))

        # Hidden linear layers: each one is followed by a ReLU.
        for layer_no, width in enumerate(fully_connected_params):
            named_layers.append(
                (f"fc_{layer_no}", nn.LazyLinear(out_features=width))
            )
            named_layers.append((f"fc_activation_{layer_no}", nn.ReLU()))

        # Output layer: K features, with no activation applied afterwards.
        named_layers.append(("final_layer", nn.LazyLinear(out_features=K)))

        self.network = nn.Sequential()
        for layer_name, layer in named_layers:
            self.network.add_module(layer_name, layer)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run ``X`` through the sequential network and return its output.

        Args:
            X: Input tensor.
                NOTE(review): presumably batched image-like input of shape
                (N, C, H, W), since it feeds a 2-D convolution first when
                ``cnn_params`` is non-empty -- confirm against callers.

        Returns:
            The output of the final linear layer (``K`` features).
        """
        output = self.network(X)
        return output