# WeMM / modeling_downsampler.py
# Last commit: f1298e6 by feipengma ("initialize wemm"), 2.06 kB.
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.utils.checkpoint
from transformers import PreTrainedModel
from transformers.activations import ACT2FN

from .configuration_downsampler import DownsamplerConfig
class DownsamplerModel(PreTrainedModel):
    """Project visual features into the LLM hidden space while downsampling.

    ``group_op`` is a strided ``nn.Conv2d``. It maps
    ``config.visual_hidden_size`` channels to ``config.llm_hidden_size``
    channels and shrinks the spatial grid according to
    ``config.kernel_size`` / ``config.stride``.

    ``linear_model`` follows it: ``config.depth - 1`` repetitions of
    ``activation -> nn.Linear(llm_hidden_size, llm_hidden_size)``. For
    ``depth <= 1`` the stack is empty and acts as the identity.
    """

    _auto_class = 'AutoModel'
    config_class = DownsamplerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True

    def __init__(self, config: DownsamplerConfig) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        # Conv2d is channels-first; ``_forward`` permutes around this call.
        self.group_op = nn.Conv2d(
            in_channels=config.visual_hidden_size,
            out_channels=config.llm_hidden_size,
            bias=config.bias,
            kernel_size=config.kernel_size,
            stride=config.stride)

        modules = []
        for _ in range(1, config.depth):
            modules.append(ACT2FN[config.hidden_act])
            modules.append(
                nn.Linear(
                    config.llm_hidden_size,
                    config.llm_hidden_size,
                    bias=config.bias))
        self.linear_model = nn.Sequential(*modules)

    def enable_input_require_grads(self) -> None:
        """Register a forward hook that marks this model's output as requiring grad.

        This follows the ``PreTrainedModel.enable_input_require_grads``
        pattern. Modules consuming this output then receive a tensor that
        requires grad even when nothing upstream does.

        The previous version hooked ``self.model``, an attribute this class
        never defines, so the call raised ``AttributeError``.

        The handle is stored in ``self._require_grads_hook``, the attribute
        ``PreTrainedModel.disable_input_require_grads`` removes.
        """

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        self._require_grads_hook = self.register_forward_hook(
            make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing on ``module``.

        This is the callback for the older ``value=``-style transformers API.
        Presumably it is applied to every submodule via ``self.apply``.
        Only ``DownsamplerModel`` instances are affected.
        """
        if isinstance(module, DownsamplerModel):
            module.gradient_checkpointing = value

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run conv-downsampling followed by the linear stack.

        ``x`` must be a 4-D channels-last tensor ``(B, H, W, C_vis)``.
        The result is ``(B, H', W', C_llm)``.
        """
        # (B, H, W, C_vis) -> (B, C_vis, H, W): Conv2d expects channels-first.
        x = x.permute(0, 3, 1, 2)
        # Strided conv with default padding=0:
        # (B, C_vis, H, W) -> (B, C_llm, H', W'),
        # where H' = (H - kernel_size) // stride + 1 (likewise for W').
        x = self.group_op(x)
        # Back to channels-last so nn.Linear acts on the channel dimension:
        # (B, C_llm, H', W') -> (B, H', W', C_llm).
        x = x.permute(0, 2, 3, 1)
        x = self.linear_model(x)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample and project ``x``.

        In training mode with gradient checkpointing enabled, activations
        are recomputed during backward instead of being stored.
        """
        if self.gradient_checkpointing and self.training:
            layer_outputs = torch.utils.checkpoint.checkpoint(self._forward, x)
        else:
            layer_outputs = self._forward(x)
        return layer_outputs