Spaces:
Runtime error
Runtime error
| from typing import Optional | |
| import torch.nn as nn | |
| import torch | |
class BasicBlock(nn.Module):
    """ResNet basic residual block.

    The main branch is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.
    The block input (optionally passed through ``identity_downsample``) is
    added to the main branch, and a final ReLU is applied to the sum.

    Parameters
    ----------
    in_channels : int
        Number of input channels
    out_channels : int
        Number of output channels
    stride : int, optional
        Stride of the first 3x3 convolution, by default 1
    identity_downsample : Optional[torch.nn.Module], optional
        Module applied to the block input before the residual addition so
        that its shape matches the main branch; when ``None`` the input is
        added unchanged, by default None
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 stride: int = 1,
                 identity_downsample: Optional[torch.nn.Module] = None):
        super().__init__()
        # Submodules are registered in a fixed order (conv1, bn1, relu, conv2,
        # bn2, shortcut). This fixes the parameter-initialisation order and
        # the state_dict key order.
        # Positional Conv2d arguments: (in, out, kernel_size, stride, padding).
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # ReLU holds no state, so a single instance serves both activations.
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.identity_downsample = identity_downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the residual block on ``x`` and return the activated sum."""
        main = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))

        # Shortcut path: project the input only when a projection was given,
        # e.g. to match a changed stride or channel count.
        shortcut = x
        if self.identity_downsample is not None:
            shortcut = self.identity_downsample(x)

        main += shortcut
        return self.relu(main)
class ResNet18(nn.Module):
    """Construct ResNet-18 Model.

    Parameters
    ----------
    input_channels : int
        Number of input channels
    num_classes : int
        Number of class outputs
    """

    def __init__(self, input_channels: int, num_classes: int):
        super(ResNet18, self).__init__()
        # Stem: 7x7 stride-2 convolution, BN, ReLU, then a stride-2 max pool.
        self.conv1 = nn.Conv2d(input_channels,
                               64, kernel_size=7,
                               stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3,
                                    stride=2,
                                    padding=1)
        # Four stages of two BasicBlocks each. Stages 2-4 use stride 2 and
        # double the channel count, so they get a projected shortcut.
        self.layer1 = self._make_layer(64, 64, stride=1)
        self.layer2 = self._make_layer(64, 128, stride=2)
        self.layer3 = self._make_layer(128, 256, stride=2)
        self.layer4 = self._make_layer(256, 512, stride=2)
        # Last layers
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def identity_downsample(self, in_channels: int, out_channels: int,
                            stride: int = 2) -> nn.Module:
        """Build the shortcut projection for a block.

        Parameters
        ----------
        in_channels : int
            Channels of the block input.
        out_channels : int
            Channels the shortcut must be projected to.
        stride : int, optional
            Stride of the projection. It must equal the stride of the block's
            first convolution so that both branches have the same spatial
            size, by default 2.

        Returns
        -------
        nn.Module
            A 3x3 convolution (padding 1) followed by batch normalization.
        """
        # NOTE: the reference ResNet uses a 1x1 projection here. The 3x3 form
        # is kept so that weights saved from this model still load.
        return nn.Sequential(
            nn.Conv2d(in_channels,
                      out_channels,
                      kernel_size=3,
                      stride=stride,
                      padding=1),
            nn.BatchNorm2d(out_channels),
        )

    def _make_layer(self, in_channels: int, out_channels: int, stride: int,
                    num_blocks: int = 2) -> nn.Module:
        """Create a stage of sequential basic blocks.

        Parameters
        ----------
        in_channels : int
            Channels entering the stage.
        out_channels : int
            Channels produced by every block in the stage.
        stride : int
            Stride of the first block; the remaining blocks use stride 1.
        num_blocks : int, optional
            Number of BasicBlocks in the stage (ResNet-18 uses 2), by default 2.

        Returns
        -------
        nn.Module
            ``nn.Sequential`` of ``num_blocks`` BasicBlocks.

        Raises
        ------
        ValueError
            If ``num_blocks`` is less than 1.
        """
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")
        identity_downsample = None
        # The first block's shortcut needs a projection whenever that block
        # changes the spatial size (stride) or the channel count. Otherwise
        # `x += identity` in BasicBlock would hit a shape mismatch.
        if stride != 1 or in_channels != out_channels:
            identity_downsample = self.identity_downsample(
                in_channels, out_channels, stride)
        blocks = [BasicBlock(in_channels, out_channels,
                             identity_downsample=identity_downsample,
                             stride=stride)]
        blocks.extend(BasicBlock(out_channels, out_channels)
                      for _ in range(num_blocks - 1))
        return nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute unnormalized class scores for a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            4-D input of shape (N, input_channels, H, W); BatchNorm2d in the
            stem requires a batched 4-D tensor.

        Returns
        -------
        torch.Tensor
            Tensor of shape (N, num_classes). No softmax is applied.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # (N, 512, 1, 1) -> (N, 512). flatten does not depend on memory
        # layout, unlike view.
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x