File size: 3,840 Bytes
c8588d0
 
 
 
 
 
 
 
 
 
 
8c18489
 
5d37cd3
c8588d0
 
 
 
 
d1dca7c
 
c8588d0
 
d1dca7c
 
 
 
 
 
 
 
 
c8588d0
 
d1dca7c
c8588d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1dca7c
 
 
 
 
 
 
 
 
c8588d0
 
d1dca7c
c8588d0
 
 
 
 
 
 
 
 
 
 
 
d1dca7c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.models import resnet18, resnet
from torchvision.models._meta import _IMAGENET_CATEGORIES
# from yacs.config import CfgNode

from dataclasses import dataclass
from typing import Optional, Tuple, List
from transformers.modeling_outputs import ModelOutput
from transformers import PretrainedConfig, PreTrainedModel
#from resnet_model.configuration_resnet import ResnetConfig 
from .configuration_resnet import ResnetConfig

@dataclass
class BaseModelOutputWithCls(ModelOutput):
    """HuggingFace-style output container for image classification.

    Attributes:
        loss: cross-entropy loss; ``None`` when no labels were supplied.
        logits: raw (unnormalized) class scores of shape (batch, num_classes).
    """

    loss: Optional[torch.FloatTensor] = None
    # Optional because the field defaults to None before the model fills it in.
    logits: Optional[torch.FloatTensor] = None



class ResnetModel(PreTrainedModel):
    """Torchvision ResNet backbone wrapped as a HuggingFace ``PreTrainedModel``.

    Returns multi-scale feature maps (one per ResNet stage) selected by
    ``config.out_indices`` — suitable as a detection/segmentation backbone.

    >>> https://huggingface.co/docs/transformers/custom_models
    >>> # Local usage
    >>> res18_config = ResnetConfig('resnet18', True)
    >>> res18_config.save_pretrained("custom-resnet")
    >>> res18_config = ResnetConfig.from_pretrained("custom-resnet")

    >>> res18_f = ResnetModel(res18_config)
    >>> res18_f.save_pretrained("custom-resnet")
    >>> res18_f = ResnetModel.from_pretrained("custom-resnet")
    """

    config_class = ResnetConfig

    def __init__(self, config):
        """Build the backbone named by ``config.model_name`` (e.g. 'resnet18').

        Args:
            config: ResnetConfig with ``model_name``, ``pretrained`` and,
                optionally, ``out_indices`` (defaults to all four stages).
        """
        super().__init__(config)
        self.model = getattr(resnet, config.model_name)(config.pretrained)
        # Drop the classification head; this wrapper only emits feature maps.
        self.model.fc = nn.Identity()

        # ``inplanes`` equals the channel count of the final stage (layer4).
        # Stage widths double each stage, so layers 1..4 output
        # c5/8, c5/4, c5/2, c5 channels (e.g. 64/128/256/512 for resnet18).
        # BUGFIX: was [c5 // 2, c5 // 4, c5 // 2, c5], which duplicated the
        # third entry and misreported layer1's width.
        c5 = self.model.inplanes
        self.output_channels = [c5 // 8, c5 // 4, c5 // 2, c5]

        out_indices = getattr(config, 'out_indices', [0, 1, 2, 3])
        self.out_indices = out_indices
        self.output_channels = [self.output_channels[i] for i in out_indices]

    def forward(self, pixel_values, **kwargs):
        """Run the backbone and return the selected stage outputs.

        Args:
            pixel_values: input image batch, (batch, 3, H, W).

        Returns:
            List of feature maps, one per index in ``self.out_indices``,
            ordered from shallow (layer1) to deep (layer4).
        """
        # BUGFIX: the original accessed self.conv1 / self.layer1 etc., but the
        # backbone lives at self.model — every call raised AttributeError.
        m = self.model
        x = m.conv1(pixel_values)
        x = m.bn1(x)
        x = m.relu(x)
        x = m.maxpool(x)

        out = []
        for stage in (m.layer1, m.layer2, m.layer3, m.layer4):
            x = stage(x)
            out.append(x)

        return [out[i] for i in self.out_indices]


class ResnetModelForImageClassification(PreTrainedModel):
    """Torchvision ResNet classifier wrapped as a HuggingFace ``PreTrainedModel``.

    >>> https://huggingface.co/docs/transformers/custom_models
    >>> # Local usage
    >>> res18_config = ResnetConfig('resnet18', True)
    >>> res18_config.save_pretrained("custom-resnet")
    >>> res18_config = ResnetConfig.from_pretrained("custom-resnet")

    >>> res18_cls = ResnetModelForImageClassification(res18_config)
    >>> res18_cls.save_pretrained("custom-resnet")
    >>> res18_cls = ResnetModelForImageClassification.from_pretrained("custom-resnet")
    """

    config_class = ResnetConfig

    def __init__(self, config):
        """Instantiate the torchvision model named by ``config.model_name``.

        The stock classification head is kept when its output width already
        matches ``config.num_classes``; otherwise it is swapped for a fresh
        ``nn.Linear`` of the right size.
        """
        super().__init__(config)
        self.model = getattr(resnet, config.model_name)(config.pretrained)
        head = self.model.fc
        if head.out_features != config.num_classes:
            self.model.fc = nn.Linear(head.in_features, config.num_classes)

    def forward(self, pixel_values, labels=None, **kwargs):
        """Classify a batch of images.

        Args:
            pixel_values: input image batch, (batch, 3, H, W).
            labels: optional class-index targets; when given, a cross-entropy
                loss is computed alongside the logits.

        Returns:
            ``BaseModelOutputWithCls`` carrying ``logits`` and (possibly None)
            ``loss``.
        """
        logits = self.model(pixel_values)
        if labels is None:
            loss = None
        else:
            loss = F.cross_entropy(logits, labels)
        return BaseModelOutputWithCls(loss=loss, logits=logits)