# Copyright (c) Facebook, Inc. and its affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # # From PyTorch: # # Copyright (c) 2016- Facebook, Inc (Adam Paszke) # Copyright (c) 2014- Facebook, Inc (Soumith Chintala) # Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) # Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) # Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) # Copyright (c) 2011-2013 NYU (Clement Farabet) # Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) # Copyright (c) 2006 Idiap Research Institute (Samy Bengio) # Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) # # From Caffe2: # # Copyright (c) 2016-present, Facebook Inc. All rights reserved. # # All contributions by Facebook: # Copyright (c) 2016 Facebook Inc. # # All contributions by Google: # Copyright (c) 2015 Google Inc. # All rights reserved. # # All contributions by Yangqing Jia: # Copyright (c) 2015 Yangqing Jia # All rights reserved. # # All contributions by Kakao Brain: # Copyright 2019-2020 Kakao Brain # # All contributions from Caffe: # Copyright(c) 2013, 2014, 2015, the respective contributors # All rights reserved. # # All other contributions: # Copyright(c) 2015, 2016 the respective contributors # All rights reserved. 
from typing import Type, Any, Callable, Union, List, Optional, Tuple

import torch

try:
    from torchvision.models.utils import load_state_dict_from_url
except ImportError:
    # Newer torchvision releases no longer ship ``torchvision.models.utils``;
    # ``torch.hub`` exposes the same helper.
    from torch.hub import load_state_dict_from_url
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet

# NOTE(review): only ``resnet50`` and ``resnext50_32x4d`` are defined in the
# part of this file that was reviewed (``ResNet`` is imported). Confirm the
# remaining names exist further down, otherwise ``from <module> import *``
# raises AttributeError.
__all__ = [
    "ResNet",
    "resnet18",
    "resnet34",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnext50_32x4d",
    "resnext101_32x8d",
    "wide_resnet50_2",
    "wide_resnet101_2",
]

# Checkpoint URL for each architecture name accepted by ``_resnet``.
model_urls = {
    "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
    "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
    "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
    "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
    "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
    "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
    "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
    "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
    "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
}


class ResNet_mine(ResNet):
    """ResNet whose forward pass also returns the ``layer4`` feature map.

    Args:
        block: Residual block class forwarded to ``ResNet.__init__``.
        layers: Per-stage block counts forwarded to ``ResNet.__init__``.
        classifier_run: If True (default), the pooled features are passed
            through ``self.fc``; if False, the flattened pooled features are
            returned as-is.
        **kwargs: Extra keyword arguments forwarded to ``ResNet.__init__``.
    """

    def __init__(self, block, layers, classifier_run=True, **kwargs):
        super().__init__(block, layers, **kwargs)
        self.classifier_run = classifier_run

    def _forward_impl(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the network and return ``(output, layer4_features)``.

        ``output`` is the ``fc`` result when ``self.classifier_run`` is true,
        otherwise the flattened output of ``avgpool``. ``layer4_features`` is
        the raw output of ``self.layer4`` (before pooling).
        """
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # Keep the pre-pooling feature map so it can be returned alongside
        # the pooled / classified output.
        x_ = self.layer4(x)

        x = self.avgpool(x_)
        x = torch.flatten(x, 1)
        if self.classifier_run:
            x = self.fc(x)

        return x, x_

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(output, layer4_features)``; see ``_forward_impl``."""
        return self._forward_impl(x)


def pnorm(weights, p):
    """Divide each row of ``weights`` by its L2 norm raised to the power ``p``.

    The norm is taken along dimension 1. ``weights`` itself is not modified;
    a new tensor is returned.

    Args:
        weights: Tensor with at least two dimensions.
        p: Exponent applied to each row norm before dividing.

    Returns:
        Tensor of the same shape as ``weights``.
    """
    # keepdim=True lets the per-row norms broadcast against ``weights``,
    # replacing the explicit Python loop over rows.
    row_norms = torch.norm(weights, 2, 1, keepdim=True)
    return weights / torch.pow(row_norms, p)


def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNet:
    """Build a ``ResNet_mine`` and optionally load weights for ``arch``.

    Args:
        arch: Key into ``model_urls`` selecting the checkpoint to download.
        block: Residual block class used to build the network.
        layers: Number of blocks in each of the four stages.
        pretrained: If True, load the state dict found at ``model_urls[arch]``.
        progress: Forwarded to ``load_state_dict_from_url``.
        **kwargs: Forwarded to ``ResNet_mine`` (e.g. ``classifier_run``).

    Returns:
        The constructed model.
    """
    model = ResNet_mine(block, layers, **kwargs)
    if pretrained:
        print("Inside resnet function, using ImageNet pretrained from model url!")
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnext50_32x4d(
    pretrained: bool = False, progress: bool = True, **kwargs: Any
) -> ResNet:
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # These two settings define the 32x4d variant and override any values
    # supplied by the caller.
    kwargs["groups"] = 32
    kwargs["width_per_group"] = 4
    return _resnet(
        "resnext50_32x4d", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs
    )


def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet("resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)