import torch
import torch.nn as nn

# Use the first GPU when available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"


class Linear(nn.Module):
    """A fully connected layer computing y = x @ W + b."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # nn.Parameter already tracks gradients, so no explicit
        # requires_grad_() call is needed; the 0.1 factor keeps the
        # initial weights small.
        self.weight = nn.Parameter(
            torch.randn((self.in_features, self.out_features), device=device) * 0.1
        )
        self.bias = nn.Parameter(
            torch.randn(self.out_features, device=device) * 0.1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight + self.bias


class ReLU(nn.Module):
    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        # clamp keeps the result on x's device and dtype; comparing against a
        # separately constructed torch.tensor(0) (a CPU int tensor) does not.
        return torch.clamp(x, min=0.0)


class Sequential(nn.Module):
    """Applies the given layers in order."""

    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


class Flatten(nn.Module):
    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        # Collapse every dimension except the batch dimension.
        return x.view(x.size(0), -1)


class DigitClassifier(nn.Module):
    """784 -> 256 -> 64 -> 10 multilayer perceptron for 28x28 digit images."""

    def __init__(self):
        super().__init__()
        self.main = Sequential(
            Flatten(),
            Linear(in_features=784, out_features=256),
            ReLU(),
            Linear(in_features=256, out_features=64),
            ReLU(),
            Linear(in_features=64, out_features=10),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.main(x)
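

# Minimal smoke test (an illustrative sketch, not part of the original module):
# push a random MNIST-sized batch through the classifier and check the output
# shape. The batch size of 32 is an arbitrary choice for this example.
if __name__ == "__main__":
    model = DigitClassifier()
    dummy = torch.randn(32, 1, 28, 28, device=device)  # batch of fake 28x28 images
    logits = model(dummy)
    assert logits.shape == (32, 10)
    print(f"logits shape: {tuple(logits.shape)}")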