Spaces:
Runtime error
Runtime error
File size: 3,131 Bytes
2de1f98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
# # Copyright (c) OpenMMLab. All rights reserved.
# import logging
# from abc import ABCMeta, abstractmethod
#
# import torch.nn as nn
#
# from .utils import load_checkpoint
#
#
# class BaseBackbone(nn.Module, metaclass=ABCMeta):
# """Base backbone.
#
# This class defines the basic functions of a backbone. Any backbone that
# inherits this class should at least define its own `forward` function.
# """
#
# def init_weights(self, pretrained=None):
# """Init backbone weights.
#
# Args:
# pretrained (str | None): If pretrained is a string, then it
# initializes backbone weights by loading the pretrained
# checkpoint. If pretrained is None, then it follows default
# initializer or customized initializer in subclasses.
# """
# if isinstance(pretrained, str):
# logger = logging.getLogger()
# load_checkpoint(self, pretrained, strict=False, logger=logger)
# elif pretrained is None:
# # use default initializer or customized initializer in subclasses
# pass
# else:
# raise TypeError('pretrained must be a str or None.'
# f' But received {type(pretrained)}.')
#
# @abstractmethod
# def forward(self, x):
# """Forward function.
#
# Args:
# x (Tensor | tuple[Tensor]): x could be a torch.Tensor or a tuple of
# torch.Tensor, containing input data for forward computation.
# """
# Copyright (c) OpenMMLab. All rights reserved.
import logging
from abc import ABCMeta, abstractmethod
import torch.nn as nn
from .utils import load_checkpoint
# from mmcv_custom.checkpoint import load_checkpoint
class BaseBackbone(nn.Module, metaclass=ABCMeta):
    """Base backbone.

    This class defines the basic functions of a backbone. Any backbone that
    inherits this class should at least define its own `forward` function
    (it is declared abstract, so the base class itself cannot be
    instantiated).
    """

    def init_weights(self, pretrained=None, patch_padding='pad'):
        """Init backbone weights.

        Args:
            pretrained (str | None): If pretrained is a string, then it
                initializes backbone weights by loading the pretrained
                checkpoint. If pretrained is None, then it follows default
                initializer or customized initializer in subclasses.
            patch_padding (str): Forwarded unchanged to ``load_checkpoint``
                as its ``patch_padding`` keyword argument; only used when
                ``pretrained`` is a string. Defaults to ``'pad'``.
                NOTE(review): the set of accepted values is defined by
                ``load_checkpoint`` in ``.utils`` -- confirm there.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            # getLogger() with no name returns the root logger; kept as in
            # the original implementation so log routing is unchanged.
            logger = logging.getLogger()
            # strict=False is passed through to load_checkpoint; presumably
            # this tolerates mismatched state-dict keys (mmcv convention) --
            # confirm in .utils.
            load_checkpoint(self, pretrained, strict=False, logger=logger,
                            patch_padding=patch_padding)
        elif pretrained is None:
            # use default initializer or customized initializer in subclasses
            pass
        else:
            raise TypeError('pretrained must be a str or None.'
                            f' But received {type(pretrained)}.')

    @abstractmethod
    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor | tuple[Tensor]): x could be a torch.Tensor or a tuple of
                torch.Tensor, containing input data for forward computation.
        """