|
""" |
|
@author: bvk1ng (Adityam Ghosh) |
|
Date: 12/28/2023 |
|
|
|
""" |
|
from typing import Any, List, Tuple, Dict, Union, Callable |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
class CNNModel(nn.Module):
    """Convolutional feature extractor followed by a fully connected head.

    The network is a single ``nn.Sequential`` built from, in order:

    * one ``LazyConv2d`` + activation pair per entry of ``cnn_params``;
    * an ``nn.Flatten`` layer;
    * one ``LazyLinear`` + activation pair per entry of
      ``fully_connected_params``;
    * a final ``LazyLinear`` with ``K`` output features and no activation.

    Every layer is lazy, so input channel and feature counts are inferred on
    the first forward pass. Until then the parameters are uninitialized.
    Run one forward pass (e.g. on a dummy batch) before handing
    ``parameters()`` to an optimizer.

    Args:
        K: Number of output features of the final linear layer.
        cnn_params: ``(out_channels, kernel_size, stride)`` triples, one per
            convolutional layer, in order. May be empty.
        fully_connected_params: Output sizes of the hidden linear layers, in
            order. May be empty.
        activation: Zero-argument factory returning the activation module
            placed after every hidden (conv and fc) layer. A fresh instance
            is created per layer. Defaults to ``nn.ReLU``, which reproduces
            the original architecture exactly.
    """

    def __init__(
        self,
        K: int,
        cnn_params: List[Tuple[int, int, int]],
        fully_connected_params: List[int],
        *,
        activation: Callable[[], nn.Module] = nn.ReLU,
    ):
        super().__init__()

        # Module names match the original layout, so state_dict keys such as
        # "network.conv2d_0.weight" or "network.final_layer.bias" are
        # unchanged and existing checkpoints still load.
        self.network = nn.Sequential()

        for idx, (out_channels, kernel_size, stride) in enumerate(cnn_params):
            self.network.add_module(
                f"conv2d_{idx}",
                nn.LazyConv2d(
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                ),
            )
            self.network.add_module(f"activation_{idx}", activation())

        self.network.add_module("flatten", nn.Flatten())

        for idx, out_feats in enumerate(fully_connected_params):
            self.network.add_module(f"fc_{idx}", nn.LazyLinear(out_features=out_feats))
            self.network.add_module(f"fc_activation_{idx}", activation())

        # No activation after the final layer: the raw K outputs are returned.
        self.network.add_module("final_layer", nn.LazyLinear(out_features=K))

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run ``X`` through the network and return ``K`` outputs per sample.

        ``nn.Flatten`` keeps dim 0 as the batch dimension. ``X`` is therefore
        presumably a batched image tensor ``(N, C, H, W)``, and the result is
        ``(N, K)``. An unbatched ``(C, H, W)`` input would not be flattened
        per sample (TODO confirm that callers always pass a batch).
        """
        return self.network(X)
|
|