# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.utils.checkpoint
import torch.nn as nn
from transformers import PreTrainedModel
from transformers.activations import ACT2FN

from .configuration_downsampler import DownsamplerConfig


class DownsamplerModel(PreTrainedModel):
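    """Downsampling projector that maps visual features into the LLM
    embedding space.

    A strided ``nn.Conv2d`` (``group_op``) reduces the spatial resolution of
    the visual feature map while projecting ``visual_hidden_size`` to
    ``llm_hidden_size``; the result is then refined by ``depth - 1``
    activation + ``nn.Linear`` blocks (``linear_model``).
    """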
    _auto_class = 'AutoModel'
    config_class = DownsamplerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True

    def __init__(self, config: DownsamplerConfig) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        self.group_op = nn.Conv2d(
            in_channels=config.visual_hidden_size,
            out_channels=config.llm_hidden_size,
            bias=config.bias,
            kernel_size=config.kernel_size, stride=config.stride)
        modules = list()
        for _ in range(1, config.depth):
            modules.append(ACT2FN[config.hidden_act])
            modules.append(
                nn.Linear(
                    config.llm_hidden_size,
                    config.llm_hidden_size,
                    bias=config.bias))
        self.linear_model = nn.Sequential(*modules)

    def enable_input_require_grads(self):

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        # Hook the module itself so its output requires grad, which keeps
        # gradients flowing into a gradient-checkpointed LLM even when the
        # modules feeding this downsampler are frozen.
        self.register_forward_hook(make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, DownsamplerModel):
            module.gradient_checkpointing = value

    def _forward(self, x):
        # Channels-last to channels-first for Conv2d:
        # (B, H, W, C_vis) -> (B, C_vis, H, W)
        x = x.permute(0, 3, 1, 2)
        # Strided conv downsamples the spatial grid and projects channels:
        # (B, C_vis, H, W) -> (B, C_llm, H', W')
        x = self.group_op(x)
        # Back to channels-last: (B, C_llm, H', W') -> (B, H', W', C_llm)
        x = x.permute(0, 2, 3, 1)
        x = self.linear_model(x)

        return x

    def forward(self, x):
        if self.gradient_checkpointing and self.training:
            layer_outputs = torch.utils.checkpoint.checkpoint(self._forward, x)
        else:
            layer_outputs = self._forward(x)
        return layer_outputs
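

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module. It assumes
    # DownsamplerConfig accepts the fields referenced above
    # (visual_hidden_size, llm_hidden_size, bias, kernel_size, stride, depth,
    # hidden_act) as keyword arguments; adjust the values to match your
    # actual vision encoder and LLM dimensions.
    config = DownsamplerConfig(
        visual_hidden_size=1024,
        llm_hidden_size=4096,
        bias=True,
        kernel_size=2,
        stride=2,
        depth=2,
        hidden_act='gelu')
    model = DownsamplerModel(config)

    # Fake (B, H, W, C_vis) visual feature map, e.g. a 24x24 patch grid.
    dummy = torch.randn(2, 24, 24, config.visual_hidden_size)
    out = model(dummy)
    # With kernel_size=stride=2 the spatial grid halves, so the expected
    # output shape is (2, 12, 12, llm_hidden_size).
    print(out.shape)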