Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision import models | |
class resnet18(torch.nn.Module):
    """Context-path backbone built on torchvision's ResNet-18.

    ``forward`` runs the stem and the four residual stages. It returns the
    outputs of the last two stages plus a spatially averaged "tail" of the
    final stage.

    NOTE: the complete torchvision model is kept as ``self.features``, and its
    stages are *also* registered as direct attributes. Both sets of names
    therefore appear in ``state_dict()``. Keep the attribute names unchanged
    so existing checkpoints still load. Submodules of ``self.features`` that
    ``forward`` never calls (presumably the avgpool/fc head) remain registered.
    Confirm this is harmless for your training setup, e.g. DDP's
    unused-parameter check.
    """

    def __init__(self, pretrained=True):
        """Wrap ``models.resnet18``.

        Args:
            pretrained: Forwarded unchanged to ``models.resnet18``.
        """
        super().__init__()
        backbone = models.resnet18(pretrained=pretrained)
        self.features = backbone
        # Stem.
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool1 = backbone.maxpool
        # Residual stages.
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, input):
        """Return ``(feature3, feature4, tail)`` for a batch of images.

        Assumes ``input`` is an (N, C, H, W) image batch; TODO confirm at the
        call site. ``tail`` is ``feature4`` averaged over its last two
        dimensions with those dimensions kept, so it has size 1 there.
        """
        stem = self.maxpool1(self.relu(self.bn1(self.conv1(input))))
        feature3 = self.layer3(self.layer2(self.layer1(stem)))  # 1 / 16
        feature4 = self.layer4(feature3)  # 1 / 32
        # Global average pooling: average over dim 3, then over dim 2.
        tail = feature4.mean(3, keepdim=True).mean(2, keepdim=True)
        return feature3, feature4, tail
class resnet101(torch.nn.Module):
    """Context-path backbone built on torchvision's ResNet-101.

    ``forward`` runs the stem and the four residual stages. It returns the
    outputs of the last two stages plus a spatially averaged "tail" of the
    final stage.

    NOTE: the complete torchvision model is kept as ``self.features``, and its
    stages are *also* registered as direct attributes. Both sets of names
    therefore appear in ``state_dict()``. Keep the attribute names unchanged
    so existing checkpoints still load. Submodules of ``self.features`` that
    ``forward`` never calls (presumably the avgpool/fc head) remain registered.
    Confirm this is harmless for your training setup, e.g. DDP's
    unused-parameter check.
    """

    def __init__(self, pretrained=True):
        """Wrap ``models.resnet101``.

        Args:
            pretrained: Forwarded unchanged to ``models.resnet101``.
        """
        super().__init__()
        backbone = models.resnet101(pretrained=pretrained)
        self.features = backbone
        # Stem.
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool1 = backbone.maxpool
        # Residual stages.
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, input):
        """Return ``(feature3, feature4, tail)`` for a batch of images.

        Assumes ``input`` is an (N, C, H, W) image batch; TODO confirm at the
        call site. ``tail`` is ``feature4`` averaged over its last two
        dimensions with those dimensions kept, so it has size 1 there.
        """
        stem = self.maxpool1(self.relu(self.bn1(self.conv1(input))))
        feature3 = self.layer3(self.layer2(self.layer1(stem)))  # 1 / 16
        feature4 = self.layer4(feature3)  # 1 / 32
        # Global average pooling: average over dim 3, then over dim 2.
        tail = feature4.mean(3, keepdim=True).mean(2, keepdim=True)
        return feature3, feature4, tail
def build_contextpath(name, *, pretrained=True):
    """Construct the context-path backbone registered under ``name``.

    Only the requested backbone is instantiated.

    Args:
        name: Backbone key, either ``'resnet18'`` or ``'resnet101'``.
        pretrained: Forwarded to the backbone constructor. Defaults to
            ``True``, matching the previously hard-coded value.

    Returns:
        A new ``resnet18`` or ``resnet101`` module instance.

    Raises:
        KeyError: If ``name`` is not a known backbone.
    """
    # Map names to zero-argument factories rather than ready-made instances.
    # The previous dict literal constructed every backbone on each call
    # (presumably loading weights for all of them when pretrained=True), even
    # though only one was returned.
    factories = {
        'resnet18': lambda: resnet18(pretrained=pretrained),
        'resnet101': lambda: resnet101(pretrained=pretrained),
    }
    try:
        factory = factories[name]
    except KeyError:
        # Keep KeyError (what the old ``model[name]`` lookup raised) so
        # existing handlers still work, but name the valid choices.
        raise KeyError(
            f"unknown context path {name!r}; expected one of {sorted(factories)}"
        ) from None
    return factory()