File size: 1,221 Bytes
c121225 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
"""
@author: bvk1ng (Adityam Ghosh)
Date: 12/28/2023
"""
from typing import Any, List, Tuple, Dict, Union, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNNModel(nn.Module):
    """Stack of convolutions followed by a fully connected head.

    The whole model is one ``nn.Sequential`` kept in ``self.network``; its
    children are registered under these names, in this order:

    1. ``conv2d_{i}`` (``nn.LazyConv2d``) then ``activation_{i}``
       (``nn.ReLU``) for every entry of ``cnn_params``.
    2. ``flatten`` (``nn.Flatten``).
    3. ``fc_{i}`` (``nn.LazyLinear``) then ``fc_activation_{i}``
       (``nn.ReLU``) for every entry of ``fully_connected_params``.
    4. ``final_layer`` (``nn.LazyLinear`` with ``K`` outputs), with no
       activation after it.

    Only output sizes are given here: every conv / linear layer is a lazy
    module, so its input size is left for PyTorch to infer on the first
    forward pass.

    Args:
        K: Number of output features of the final linear layer.
        cnn_params: One ``(out_channels, kernel_size, stride)`` triple per
            convolutional layer.
        fully_connected_params: ``out_features`` of each hidden linear layer.
    """

    def __init__(self, K: int, cnn_params: List, fully_connected_params: List):
        super().__init__()

        stack = nn.Sequential()

        # Convolutional part: conv -> ReLU for each configured layer.
        for layer_no, spec in enumerate(cnn_params):
            n_filters, k_size, step = spec
            conv = nn.LazyConv2d(
                out_channels=n_filters,
                kernel_size=k_size,
                stride=step,
            )
            stack.add_module(f"conv2d_{layer_no}", conv)
            stack.add_module(f"activation_{layer_no}", nn.ReLU())

        # Collapse the feature maps so the linear layers can consume them.
        stack.add_module("flatten", nn.Flatten())

        # Hidden fully connected part: linear -> ReLU for each width.
        for layer_no, width in enumerate(fully_connected_params):
            stack.add_module(f"fc_{layer_no}", nn.LazyLinear(out_features=width))
            stack.add_module(f"fc_activation_{layer_no}", nn.ReLU())

        # Output projection; deliberately left without an activation.
        stack.add_module("final_layer", nn.LazyLinear(out_features=K))

        self.network = stack

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run ``X`` through ``self.network`` and return its output."""
        output = self.network(X)
        return output
|