# Modality Projection from Vision to Language
import torch.nn as nn


class ModalityProjector(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.input_dim = cfg.vit_hidden_dim * (cfg.mp_pixel_shuffle_factor**2)
        self.output_dim = cfg.lm_hidden_dim
        self.scale_factor = cfg.mp_pixel_shuffle_factor

        self.proj = nn.Linear(self.input_dim, self.output_dim, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(self.proj.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    # https://github.com/huggingface/smollm/blob/main/vision/m4/models/vllama3/modeling_vllama3.py#L1281
    def pixel_shuffle(self, x):
        bsz, seq, embed_dim = x.size()
        seq_root = int(seq**0.5)
        assert seq_root**2 == seq  # Sequence length must be a perfect square for pixel shuffle
        assert seq_root % self.scale_factor == 0  # Sequence root must be divisible by scale factor

        height = width = seq_root
        x = x.view(bsz, height, width, embed_dim)
        h_out = height // self.scale_factor
        w_out = width // self.scale_factor

        x = x.reshape(bsz, h_out, self.scale_factor, w_out, self.scale_factor, embed_dim)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
        x = x.reshape(bsz, h_out * w_out, embed_dim * self.scale_factor**2)
        return x
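
    # Worked shape trace for the reshapes above. The numbers are an illustrative
    # assumption (a 14x14 grid of 768-dim ViT tokens, scale_factor=2), not values
    # taken from this file:
    #   input   (B, 196, 768)            -> view    (B, 14, 14, 768)
    #   reshape (B, 7, 2, 7, 2, 768)     -> permute (B, 7, 7, 2, 2, 768)
    #   reshape (B, 49, 3072)
    # i.e. scale_factor**2 = 4x fewer tokens, each 4x wider in feature dimension.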

    def forward(self, x):
        x = self.pixel_shuffle(x)
        x = self.proj(x)
        return x
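

# Minimal usage sketch (not part of the original module). The config values below
# are illustrative assumptions chosen so the shapes work out, not repository defaults.
if __name__ == "__main__":
    from types import SimpleNamespace
    import torch

    cfg = SimpleNamespace(
        vit_hidden_dim=768,         # assumed ViT feature width
        lm_hidden_dim=576,          # assumed language-model hidden size
        mp_pixel_shuffle_factor=2,  # assumed pixel-shuffle factor
    )
    mp = ModalityProjector(cfg)

    # Batch of 2 images, each encoded as a 14x14 = 196 grid of vision tokens.
    vision_tokens = torch.randn(2, 196, cfg.vit_hidden_dim)
    out = mp(vision_tokens)

    # Pixel shuffle cuts the token count by factor**2 (196 -> 49) and widens each
    # token; the linear layer then maps it into the LM embedding space.
    assert out.shape == (2, 196 // cfg.mp_pixel_shuffle_factor**2, cfg.lm_hidden_dim)
    print(out.shape)  # torch.Size([2, 49, 576])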