import torch.nn as nn


class ModalityProjector(nn.Module):
    """Maps vision-encoder patch embeddings into the language model's hidden
    space, first compressing the token sequence with a pixel-shuffle step."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Pixel shuffle folds each (factor x factor) block of patches into the
        # channel dimension, so the projection input grows by factor**2.
        self.input_dim = cfg.vit_hidden_dim * (cfg.mp_pixel_shuffle_factor**2)
        self.output_dim = cfg.lm_hidden_dim
        self.scale_factor = cfg.mp_pixel_shuffle_factor

        self.proj = nn.Linear(self.input_dim, self.output_dim, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # Initialize the module handed in by self.apply(), not self.proj:
            # hard-coding self.proj would silently skip any other Linear.
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def pixel_shuffle(self, x):
        """Rearrange (bsz, seq, dim) -> (bsz, seq // factor**2, dim * factor**2).

        The sequence is treated as a square grid of patches; each
        (factor x factor) neighborhood is folded into the channel dimension,
        cutting the number of visual tokens by factor**2.
        """
        bsz, seq, embed_dim = x.size()
        seq_root = int(seq**0.5)
        assert seq_root**2 == seq, "sequence length must be a perfect square"
        assert seq_root % self.scale_factor == 0, "grid side must be divisible by the shuffle factor"

        height = width = seq_root
        x = x.view(bsz, height, width, embed_dim)
        h_out = height // self.scale_factor
        w_out = width // self.scale_factor

        # Split each spatial axis into (block index, within-block offset),
        # move the two within-block axes next to the channels, then flatten
        # them into the channel dimension.
        x = x.reshape(bsz, h_out, self.scale_factor, w_out, self.scale_factor, embed_dim)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        x = x.reshape(bsz, h_out * w_out, embed_dim * self.scale_factor**2)

        return x

    def forward(self, x):
        # Compress the visual token sequence, then project each token into
        # the language model's hidden size.
        x = self.pixel_shuffle(x)
        x = self.proj(x)
        return x
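

# Equivalent formulation (assumption): if einops is installed, the same
# space-to-depth rearrangement can be written as a single rearrange call.
# This is an illustrative cross-check, not part of the module above; the
# function name is hypothetical.
def pixel_shuffle_einops(x, factor):
    """Same mapping as ModalityProjector.pixel_shuffle, via einops.rearrange."""
    from einops import rearrange

    side = int(x.size(1) ** 0.5)  # tokens are assumed to form a square grid
    return rearrange(
        x,
        "b (h ph w pw) d -> b (h w) (ph pw d)",
        h=side // factor, ph=factor, pw=factor,
    )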
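

# Usage sketch (assumption): a minimal smoke test with a stand-in config.
# SimpleNamespace mimics the cfg object this module expects; the field
# values are illustrative, not the project's real defaults.
if __name__ == "__main__":
    from types import SimpleNamespace

    import torch

    cfg = SimpleNamespace(
        vit_hidden_dim=768,        # hypothetical ViT embedding width
        lm_hidden_dim=576,         # hypothetical LM hidden size
        mp_pixel_shuffle_factor=2,
    )
    mp = ModalityProjector(cfg)

    # 64 patch tokens form an 8x8 grid; a factor of 2 yields 16 output tokens.
    x = torch.randn(1, 64, cfg.vit_hidden_dim)
    out = mp(x)
    assert out.shape == (1, 16, cfg.lm_hidden_dim)
    print(out.shape)  # torch.Size([1, 16, 576])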