File size: 1,221 Bytes
c121225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""
@author: bvk1ng (Adityam Ghosh)
Date: 12/28/2023

"""
from typing import Any, List, Tuple, Dict, Union, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


class CNNModel(nn.Module):
    """Stack of lazily-shaped conv and fully-connected layers ending in ``K`` outputs.

    Layer order inside ``self.network`` (an ``nn.Sequential``):

    1. For each entry of ``cnn_params``: ``conv2d_<i>`` (``nn.LazyConv2d``)
       followed by ``activation_<i>`` (``nn.ReLU``).
    2. ``flatten`` (``nn.Flatten``).
    3. For each entry of ``fully_connected_params``: ``fc_<i>``
       (``nn.LazyLinear``) followed by ``fc_activation_<i>`` (``nn.ReLU``).
    4. ``final_layer`` (``nn.LazyLinear`` with ``K`` output features). No
       activation follows it, so the outputs are raw scores.

    Every conv/linear layer is "lazy": its input size is inferred on the
    first forward pass, so only output sizes are configured here.

    Args:
        K: Number of output features of the final linear layer (presumably
            the number of classes/actions — confirm against callers).
        cnn_params: Sequence of ``(out_channels, kernel_size, stride)``
            triples, one per convolutional layer. Padding is left at the
            ``LazyConv2d`` default.
        fully_connected_params: Sequence of hidden-layer widths, one per
            fully-connected layer inserted before the final layer.
    """

    def __init__(self, K: int, cnn_params: List, fully_connected_params: List):
        super().__init__()

        # Gather (child-name, module) pairs first; registering them in this
        # exact order preserves the original child names / state_dict keys.
        named_layers: List[Tuple[str, nn.Module]] = []

        for layer_no, conv_spec in enumerate(cnn_params):
            n_filters, ksize, step = conv_spec
            conv = nn.LazyConv2d(
                out_channels=n_filters,
                kernel_size=ksize,
                stride=step,
            )
            named_layers.append((f"conv2d_{layer_no}", conv))
            named_layers.append((f"activation_{layer_no}", nn.ReLU()))

        named_layers.append(("flatten", nn.Flatten()))

        for layer_no, width in enumerate(fully_connected_params):
            hidden = nn.LazyLinear(out_features=width)
            named_layers.append((f"fc_{layer_no}", hidden))
            named_layers.append((f"fc_activation_{layer_no}", nn.ReLU()))

        named_layers.append(("final_layer", nn.LazyLinear(out_features=K)))

        self.network = nn.Sequential()
        for child_name, child in named_layers:
            self.network.add_module(child_name, child)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Run ``X`` through the layer stack and return the final-layer output.

        NOTE(review): ``nn.Flatten`` here uses its default ``start_dim=1``, so
        this presumably expects batched ``(N, C, H, W)`` input — confirm.
        """
        logits = self.network(X)
        return logits