Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class SharedEncoder(nn.Module):
    """Shared feature encoder: a VGG-style stack of four double-conv stages.

    Each stage applies two 3x3 convolutions (stride 1, padding 1, so H and W
    are preserved), each followed by ReLU. Stages 1-3 are followed by a 2x2
    max-pool with stride 2, so the final feature map is downsampled three
    times (floor division by 2 at each pooling step).

    Args:
        channel_list: Sequence of channel widths. Only the first four entries
            (c1, c2, c3, c4) are used by this encoder. Any further entries
            (the original code unpacked and ignored c5, d1, d2) are accepted
            and ignored. Presumably they configure heads defined elsewhere
            (TODO confirm).
        in_channels: Number of channels in the input tensor. Defaults to 1,
            matching the previously hard-coded single-channel input.

    Raises:
        ValueError: If ``channel_list`` has fewer than four entries.
    """

    def __init__(self, channel_list, in_channels=1):
        super().__init__()
        # Extra trailing entries are tolerated so the same list can be shared
        # with other modules; only the four encoder widths matter here.
        c1, c2, c3, c4, *_ = channel_list

        def conv3x3(cin, cout):
            # 3x3 kernel, stride 1, padding 1: spatial size is unchanged.
            return nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1)

        # In-place ReLU is only ever applied to fresh conv outputs in forward().
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Attribute names and creation order are kept unchanged so state_dict
        # keys (conv1a.weight, ...) and seeded initialization stay compatible.
        self.conv1a = conv3x3(in_channels, c1)
        self.conv1b = conv3x3(c1, c1)
        self.conv2a = conv3x3(c1, c2)
        self.conv2b = conv3x3(c2, c2)
        self.conv3a = conv3x3(c2, c3)
        self.conv3b = conv3x3(c3, c3)
        self.conv4a = conv3x3(c3, c4)
        self.conv4b = conv3x3(c4, c4)

    def forward(self, x):
        """Encode ``x``; return the deepest features plus skip features.

        Args:
            x: Input tensor, presumably shaped (N, in_channels, H, W).
                TODO confirm with callers.

        Returns:
            A tuple ``(features, skips)``:

            - ``features`` is the stage-4 output with c4 channels, after three
              2x2 pools.
            - ``skips`` is ``[conv1, conv2, conv3]``, the post-ReLU outputs of
              stages 1-3. They are taken *before* pooling, i.e. at full, 1/2
              and 1/4 of the input resolution.
        """
        # Stage 1
        x = self.relu(self.conv1a(x))
        conv1 = self.relu(self.conv1b(x))
        x = self.pool(conv1)
        # Stage 2
        x = self.relu(self.conv2a(x))
        conv2 = self.relu(self.conv2b(x))
        x = self.pool(conv2)
        # Stage 3
        x = self.relu(self.conv3a(x))
        conv3 = self.relu(self.conv3b(x))
        x = self.pool(conv3)
        # Stage 4: no pooling afterwards.
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))
        return x, [conv1, conv2, conv3]