# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
# ``forward`` calls ``torch.utils.checkpoint.checkpoint``; import the
# submodule explicitly rather than relying on ``import torch`` to load it.
import torch.utils.checkpoint
from transformers import PreTrainedModel
from transformers.activations import ACT2FN

from .configuration_downsampler import DownsamplerConfig


class DownsamplerModel(PreTrainedModel):
    """Conv2d downsampler followed by an optional stack of linear layers.

    ``group_op`` is an ``nn.Conv2d`` mapping ``config.visual_hidden_size``
    input channels to ``config.llm_hidden_size`` output channels, using
    ``config.kernel_size`` and ``config.stride``. It is followed by
    ``linear_model``: ``config.depth - 1`` repetitions of
    (``ACT2FN[config.hidden_act]``,
    ``nn.Linear(llm_hidden_size, llm_hidden_size)``). When ``config.depth``
    is 1 or less, ``linear_model`` is an empty ``nn.Sequential``, which
    returns its input unchanged.

    ``config.bias`` controls the bias of both the conv and the linear layers.
    """

    _auto_class = 'AutoModel'
    config_class = DownsamplerConfig
    base_model_prefix = 'model'
    supports_gradient_checkpointing = True

    def __init__(self, config: DownsamplerConfig) -> None:
        super().__init__(config)
        self.gradient_checkpointing = False

        self.group_op = nn.Conv2d(
            in_channels=config.visual_hidden_size,
            out_channels=config.llm_hidden_size,
            bias=config.bias,
            kernel_size=config.kernel_size,
            stride=config.stride)

        modules = []
        for _ in range(1, config.depth):
            modules.append(ACT2FN[config.hidden_act])
            modules.append(
                nn.Linear(
                    config.llm_hidden_size,
                    config.llm_hidden_size,
                    bias=config.bias))
        self.linear_model = nn.Sequential(*modules)

    def enable_input_require_grads(self):
        """Register a forward hook that marks this module's output as
        requiring grad.

        The hook is registered on the module itself, so it fires on every
        ``self(...)`` call and receives the final output of ``forward``.
        """

        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        # This class has no ``model`` submodule (only ``group_op`` and
        # ``linear_model``), so hook the module itself. The handle is stored
        # under the attribute name transformers' ``PreTrainedModel`` uses, so
        # ``disable_input_require_grads()`` can remove it.
        self._require_grads_hook = self.register_forward_hook(
            make_inputs_require_grad)

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle ``gradient_checkpointing`` on ``module`` if it is a
        ``DownsamplerModel``.

        NOTE(review): this is the legacy ``(module, value)`` hook signature
        used by older transformers versions; confirm it matches the pinned
        transformers release.
        """
        if isinstance(module, DownsamplerModel):
            module.gradient_checkpointing = value

    def _forward(self, x):
        """Apply the conv downsampler, then the linear stack.

        ``x`` must be rank 4 and channel-last: the permute below moves the
        last dim into the ``Conv2d`` channel position, so that dim must be
        ``config.visual_hidden_size``.
        """
        # (B, FULL_H, FULL_W, D) -> (B, D, FULL_H, FULL_W)
        x = x.permute(0, 3, 1, 2)
        # The spatial size shrinks according to kernel_size / stride.
        x = self.group_op(x)
        # (B, D', H', W') -> (B, H', W', D'), where D' = llm_hidden_size
        x = x.permute(0, 2, 3, 1)
        x = self.linear_model(x)
        return x

    def forward(self, x):
        """Run ``_forward``, using activation checkpointing when
        ``gradient_checkpointing`` is enabled and the module is in training
        mode.
        """
        # NOTE(review): ``use_reentrant`` is left at torch's default to
        # preserve existing behaviour; newer torch versions warn when it is
        # not passed explicitly.
        if self.gradient_checkpointing and self.training:
            layer_outputs = torch.utils.checkpoint.checkpoint(self._forward, x)
        else:
            layer_outputs = self._forward(x)
        return layer_outputs